<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 4: Just-In-Time (JIT) Compilation (Exercises 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, configure JAX as shown in the code snippet below.
* The exercises use a sample function, `cube`.

In [1]:
!python3 -m pip install jax



In [2]:
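import jax
import jax.numpy as jnp
import os
import requests

# When a TPU is attached, TPU_NAME holds its address ('grpc://<ip>:<port>').
# Request the nightly TPU driver and point JAX's tpu_driver backend at it.
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # No TPU (TPU_NAME unset, driver unreachable, or this config API missing):
    # keep the default CPU/GPU backend.
    pass

jax.devices()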

[CpuDevice(id=0)]

In [3]:
def cube(x):
    return x**3

cube(2.1)

9.261000000000001

**Exercise 31: JIT-compile the `cube` function and assign it to `cube_jit`**

In [7]:
cube_jit = jax.jit(cube)
cube_jit(2.1)

DeviceArray(9.260999, dtype=float32, weak_type=True)
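The jitted result (9.260999) differs from the plain Python result (9.261000000000001) because JAX computes in 32-bit floats by default, while Python floats are 64-bit. A small check, assuming the default configuration (the `jax_enable_x64` flag left off); `out` is an illustrative name:

```python
import jax
import numpy as np

def cube(x):
    return x**3

cube_jit = jax.jit(cube)
out = cube_jit(2.1)

# The Python float 2.1 is traced as a (weakly typed) float32 scalar,
# assuming 64-bit mode has not been enabled.
print(out.dtype)

# Cubing in NumPy float32 agrees with the jitted result to float32 precision,
# while plain Python (float64) arithmetic keeps more digits.
print(np.isclose(float(out), float(np.float32(2.1) ** 3)))
print(cube(2.1))
```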

**Exercise 32: Display execution time of `cube_jit` for first run (with overhead) with input=10.24**

In [8]:
%timeit cube(10.24)
%timeit cube_jit(10.24)

77.3 ns ± 1.12 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
2.41 µs ± 26.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
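`%timeit` repeats the call many times, so a one-off tracing and compilation cost is averaged away; moreover, `cube_jit` was already compiled for a scalar float input in Exercise 31, so no compilation happens in this cell at all. A sketch of timing a genuine first call, assuming a freshly defined function (`cube_fresh` and `cube_fresh_jit` are hypothetical names) so that JAX has nothing cached for it. `block_until_ready()` is needed because JAX dispatches work asynchronously:

```python
import time
import jax

def cube_fresh(x):
    # Hypothetical copy of `cube`, defined anew so JAX has nothing cached for it
    return x**3

cube_fresh_jit = jax.jit(cube_fresh)

start = time.perf_counter()
# First call: trace + compile + execute. block_until_ready() waits for the
# asynchronously dispatched result before the clock is stopped.
first_result = cube_fresh_jit(10.24).block_until_ready()
first_call_ms = (time.perf_counter() - start) * 1e3

print(f"first call (with compilation): {first_call_ms:.2f} ms")
```

In a notebook cell, `%time cube_fresh_jit(10.24).block_until_ready()` run on a fresh function gives the same single-shot measurement.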


**Exercise 33: Display execution time of `cube_jit` for second run (without overhead) with input=10.24**

In [10]:
%timeit cube(10.24)
%timeit cube_jit(10.24)

76.2 ns ± 0.448 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
2.39 µs ± 7.69 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
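Both `%timeit` cells measure steady-state calls to the already-compiled `cube_jit`. They also show the jitted call to be slower than plain Python (about 2.4 µs vs 77 ns): for a single scalar operation, JAX's per-call dispatch overhead outweighs anything compilation saves. A sketch of a steady-state measurement with the standard `timeit` module, using an explicit warm-up call and `block_until_ready()` so each timing covers the computation itself; `runs`, `best_per_call_us` and the repeat counts are illustrative choices:

```python
import timeit
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)
cube_jit(10.24).block_until_ready()  # warm-up: pay the compilation cost once

# Time only steady-state calls; block_until_ready() makes each timing include
# the actual computation rather than just the asynchronous dispatch.
runs = timeit.repeat(lambda: cube_jit(10.24).block_until_ready(),
                     number=1000, repeat=5)
best_per_call_us = min(runs) / 1000 * 1e6
print(f"best per-call time: {best_per_call_us:.2f} µs")
```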


**Exercise 34: Run `cube_jit` with input=10.24 and assign it to `cube_value`**

In [11]:
cube_value = cube_jit(10.24)
cube_value

DeviceArray(1073.7418, dtype=float32, weak_type=True)

**Exercise 35: Run `cube_jit` with jit disabled and input=10.24 and assign it to `cube_value_nojit`**

In [13]:
with jax.disable_jit():
    cube_value_nojit = cube_jit(10.24)
cube_value_nojit

1073.7418240000002
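In the outputs above, the eager result (1073.7418240000002) keeps more digits than the jitted one (1073.7418): inside `jax.disable_jit()`, `cube` runs as ordinary Python on the concrete input, whereas the compiled version computes in float32. A short sketch comparing the two under the default 32-bit configuration:

```python
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

cube_value = cube_jit(10.24)  # compiled; float32 by default
with jax.disable_jit():
    # With JIT disabled, cube runs eagerly on concrete values (handy for debugging)
    cube_value_nojit = cube_jit(10.24)

# Gap between the float32 (jitted) and eager results
print(abs(float(cube_value) - float(cube_value_nojit)))
```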

**Exercise 36: Evaluate the shape of `cube_jit` with input=10.24 and assign it to `cube_shape`**

In [15]:
cube_shape = cube_jit(10.24).shape
cube_shape

()
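The cell above obtains the shape by actually running `cube_jit`. `jax.eval_shape` returns the output shape and dtype by tracing the function abstractly, without executing it, which can be much cheaper for large computations. A sketch, where `cube_shape_struct` is an illustrative name:

```python
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

# Abstract evaluation: no arithmetic is performed, only the output
# shape and dtype are worked out (returned as a ShapeDtypeStruct).
cube_shape_struct = jax.eval_shape(cube_jit, 10.24)
print(cube_shape_struct.shape, cube_shape_struct.dtype)
```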

**Exercise 37: Create the jaxpr of `cube_jit` with input=10.24 and assign it to `cube_jaxpr`**

In [None]:
cube_jaxpr = jax.make_jaxpr(cube_jit)(10.24)
cube_jaxpr
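`jax.make_jaxpr` returns a `ClosedJaxpr`. For the jitted function, the body appears wrapped in a single nested call equation (its primitive name differs between JAX versions); tracing the plain `cube` shows the body directly. A sketch comparing the two, with `cube_plain_jaxpr` as an illustrative name:

```python
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

cube_jaxpr = jax.make_jaxpr(cube_jit)(10.24)    # body wrapped in a nested call
cube_plain_jaxpr = jax.make_jaxpr(cube)(10.24)  # body traced directly

print(cube_jaxpr)
print(cube_plain_jaxpr)

# The equations can be inspected programmatically: x**3 with an integer
# exponent traces to a single integer_pow primitive.
print([eqn.primitive.name for eqn in cube_plain_jaxpr.jaxpr.eqns])
```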

**Exercise 38: Assign the XLA computation of `cube_jit` with input=10.24 to `cube_xla` and print its XLA HLO text**
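Older JAX releases, such as the one this notebook ran on, exposed this through `jax.xla_computation`, whose result has an `as_hlo_text()` method; that function has since been deprecated and removed. A sketch using the ahead-of-time lowering API (`.lower()` and `.compile()`) instead, assuming it is acceptable for `cube_xla` to hold a lowered computation rather than an `XlaComputation`; `cube_hlo_text` is an illustrative name:

```python
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

# Ahead-of-time API: lower for a concrete example input.
cube_xla = cube_jit.lower(10.24)
print(cube_xla.as_text())  # StableHLO (MLIR) text in current releases

# Compiling yields an executable whose text is XLA HLO after optimization.
cube_hlo_text = cube_xla.compile().as_text()
print(cube_hlo_text)
```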

**Exercise 39: Use the name `jaxton_cube_fn` internally for the `cube_jit` function and assign the named function to `cube_named_jit`**
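A minimal sketch, assuming the exercise refers to `jax.named_call`, which wraps a function so that the given name is attached to its operations when the program is traced and lowered:

```python
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

# Attach the name "jaxton_cube_fn" to everything traced inside cube_jit,
# then jit the named wrapper so it can be lowered and compiled as a whole.
cube_named_jit = jax.jit(jax.named_call(cube_jit, name="jaxton_cube_fn"))

print(cube_named_jit(10.24))  # same value as cube_jit(10.24)
```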

**Exercise 40: Assign the XLA computation of `cube_named_jit` with input=10.24 to `cube_named_xla` and print its XLA HLO text**
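Assuming `jax.named_call` for the naming and the lowering API in place of the removed `jax.xla_computation`, a sketch that lowers the named function and prints the HLO of the compiled program; `cube_named_hlo_text` is an illustrative name. Where `jaxton_cube_fn` shows up in the text (for example in operation metadata) depends on the JAX version:

```python
import jax

def cube(x):
    return x**3

# Named, jitted version of cube (jax.named_call is an assumption, see above)
cube_named_jit = jax.jit(jax.named_call(jax.jit(cube), name="jaxton_cube_fn"))

# Lower for the example input; the Lowered object stands in for the older
# XlaComputation, which recent JAX no longer provides.
cube_named_xla = cube_named_jit.lower(10.24)

# XLA HLO text of the compiled program
cube_named_hlo_text = cube_named_xla.compile().as_text()
print(cube_named_hlo_text)
```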

<center>
    This completes Set 4: Just-In-Time (JIT) Compilation (Exercises 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>