<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 4: Just-In-Time (JIT) Compilation (Exercises 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, JAX must be configured as shown in the code snippet below.
* A sample function `cube` will be used for the exercises.

In [None]:
!python3 -m pip install jax

In [1]:
import jax
import jax.numpy as jnp
import os
import requests

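try:
    # Use the TPU driver when a TPU is attached (TPU_NAME is set);
    # otherwise fall back silently to JAX's default backend.
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    pass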

jax.devices()



[CpuDevice(id=0)]
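Whichever setup path runs above, it can help to confirm which backend JAX actually picked up before timing anything. The sketch below only uses `jax.default_backend()` and `jax.devices()`; the variable names are illustrative.

```python
import jax

backend = jax.default_backend()   # e.g. "cpu", "gpu" or "tpu"
devices = jax.devices()           # every device visible to this process

print(backend, len(devices))
for d in devices:
    # Each device reports its id and the platform it belongs to
    print(d.id, d.platform)
```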

In [2]:
def cube(x):
    return x**3

cube(2.1)

9.261000000000001

**Exercise 31: JIT-compile the `cube` function and assign it to `cube_jit`**

In [19]:
cube_jit = jax.jit(cube)
jit_cube = cube_jit  # alias: the solution cells below refer to it as `jit_cube`
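`jax.jit` can also be applied as a decorator. The sketch below (with a hypothetical `cube_traced` function and `trace_count` counter) uses a Python side effect to show when JAX re-traces: the side effect runs only during tracing, so it counts compilations rather than calls. Calls with the same input shape and dtype reuse the cached executable; a new shape triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit  # decorator form, equivalent to cube_traced = jax.jit(cube_traced)
def cube_traced(x):
    global trace_count
    # Python side effects execute only while JAX traces the function,
    # so this counts cache misses (compilations), not calls.
    trace_count += 1
    return x ** 3

cube_traced(10.24)            # first call: traced and compiled for a float32 scalar
cube_traced(2.5)              # same shape and dtype: cached, no retrace
cube_traced(jnp.arange(3.0))  # new shape (3,): traced and compiled again
print(trace_count)            # 2
```

Changing only the *value* of a traced argument never triggers recompilation; changing its shape or dtype does.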

**Exercise 32: Display execution time of `cube_jit` for the first run (with compilation overhead) with input=10.24**

In [20]:
%%time
jit_cube(10.24)

CPU times: user 1.32 ms, sys: 0 ns, total: 1.32 ms
Wall time: 1.28 ms


DeviceArray(1073.7418, dtype=float32, weak_type=True)

**Exercise 33: Display execution time of `cube_jit` for the second run (without compilation overhead) with input=10.24**

In [21]:
%%time
jit_cube(10.24)

CPU times: user 111 µs, sys: 23 µs, total: 134 µs
Wall time: 198 µs


DeviceArray(1073.7418, dtype=float32, weak_type=True)
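JAX dispatches work asynchronously, so a timer can stop before the device has actually finished. A minimal timing sketch, assuming a hypothetical `timed` helper built on `time.perf_counter`, waits with `block_until_ready()` so both runs are measured to completion:

```python
import time
import jax

def cube(x):
    return x ** 3

cube_jit = jax.jit(cube)

def timed(fn, *args):
    """Return (result, seconds) for one call, waiting for the device to finish."""
    start = time.perf_counter()
    # block_until_ready() blocks until the asynchronously dispatched result exists
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start

result, first = timed(cube_jit, 10.24)  # includes tracing and XLA compilation
_, second = timed(cube_jit, 10.24)      # reuses the cached executable
print(f"first: {first * 1e3:.3f} ms, second: {second * 1e3:.3f} ms")
```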

**Exercise 34: Run `cube_jit` with input=10.24 and assign it to `cube_value`**

In [25]:
cube_value = jit_cube(10.24)
cube_value

DeviceArray(1073.7418, dtype=float32, weak_type=True)
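The printed value is `1073.7418` rather than Python's `1073.741824` because JAX defaults to 32-bit floats. The sketch below, assuming the default configuration (`jax_enable_x64` off), checks the dtype, converts the 0-d array to a Python float, and reproduces the same rounding with NumPy `float32`:

```python
import numpy as np
import jax

def cube(x):
    return x ** 3

cube_value = jax.jit(cube)(10.24)
print(cube_value.dtype)    # float32 under the default 32-bit configuration
print(float(cube_value))   # convert the 0-d array to a Python float

# The same product computed in NumPy float32 rounds the same way
x32 = np.float32(10.24)
x32_cubed = x32 * x32 * x32
print(x32_cubed)
print(10.24 ** 3)          # plain Python float64 arithmetic
```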

**Exercise 35: Run `cube_jit` with jit disabled and input=10.24 and assign it to `cube_value_nojit`**

In [24]:
with jax.disable_jit():
    cube_value_nojit = jit_cube(10.24)
cube_value_nojit

1073.7418240000002
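`jax.disable_jit()` is mostly a debugging aid: inside the context, jitted functions run eagerly on concrete values, so Python control flow and `print` see real numbers. The sketch below uses a hypothetical `cube_if_positive` function whose Python `if` works with jit disabled but fails while tracing under jit.

```python
import jax
import jax.numpy as jnp

@jax.jit
def cube_if_positive(x):
    # A Python `if` needs a concrete value of x. Under jit, x is an abstract
    # tracer, so evaluating this condition raises an error during tracing.
    if x > 0:
        return x ** 3
    return jnp.zeros_like(x)

with jax.disable_jit():
    # The decorated function now runs eagerly on concrete values
    eager_pos = cube_if_positive(2.0)
    eager_neg = cube_if_positive(-1.0)
print(eager_pos, eager_neg)

try:
    cube_if_positive(2.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
print("tracing under jit failed:", jit_failed)
```

Under jit, such branches would normally be rewritten with `jnp.where` or `jax.lax.cond`.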

**Exercise 36: Evaluate the output shape of `cube_jit` with input=10.24 and assign it to `cube_shape`**

In [30]:
cube_shape = jit_cube(10.24).shape
cube_shape

()
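Calling the function just to read `.shape` performs the full computation. `jax.eval_shape` instead propagates only shapes and dtypes through the function, and it accepts `jax.ShapeDtypeStruct` specs so no real input array is needed. A short sketch (the 1024×1024 spec is an arbitrary example; the dtype comments assume the default 32-bit mode):

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x ** 3

cube_jit = jax.jit(cube)

# Abstract evaluation: shapes and dtypes are propagated, no FLOPs are spent
out_scalar = jax.eval_shape(cube_jit, 10.24)
print(out_scalar.shape, out_scalar.dtype)   # () float32

# Inputs can be pure shape specs, so no real array is allocated
spec = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)
out_matrix = jax.eval_shape(cube_jit, spec)
print(out_matrix.shape, out_matrix.dtype)   # (1024, 1024) float32
```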

**Exercise 37: Create the jaxpr of `cube_jit` with input=10.24 and assign it to `cube_jaxpr`**

In [35]:
cube_jaxpr = jax.make_jaxpr(jit_cube)(10.24)
cube_jaxpr

{ lambda ; a:f32[]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[]. let d:f32[] = integer_pow[y=3] c in (d,) }
      name=cube
    ] a
  in (b,) }

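A jaxpr is also a data structure that can be inspected programmatically. The sketch below traces the plain (un-jitted) `cube`, so the jaxpr contains just the `integer_pow` equation seen inside `call_jaxpr` above, and walks its equations via the `jaxpr`, `eqns`, `primitive` and `params` attributes.

```python
import jax

def cube(x):
    return x ** 3

closed = jax.make_jaxpr(cube)(10.24)   # a ClosedJaxpr
jaxpr = closed.jaxpr                   # the underlying Jaxpr
print(jaxpr.invars, jaxpr.outvars)

for eqn in jaxpr.eqns:
    # Each equation records a primitive, its parameters, inputs and outputs
    print(eqn.primitive.name, eqn.params, eqn.invars, eqn.outvars)

primitive_names = [eqn.primitive.name for eqn in jaxpr.eqns]
print(primitive_names)   # ['integer_pow']
```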
**Exercise 38: Assign the XLA computation of `cube_jit` with input=10.24 to `cube_xla` and print its XLA HLO text**

In [38]:
cube_xla = jax.xla_computation(jit_cube)(10.24)
print(cube_xla.as_hlo_text())

HloModule xla_computation_cube, entry_computation_layout={(f32[])->(f32[])}

cube.2 {
  Arg_0.3 = f32[] parameter(0)
  multiply.4 = f32[] multiply(Arg_0.3, Arg_0.3)
  ROOT multiply.5 = f32[] multiply(Arg_0.3, multiply.4)
}

ENTRY main.8 {
  Arg_0.1 = f32[] parameter(0)
  call.6 = f32[] call(Arg_0.1), to_apply=cube.2
  ROOT tuple.7 = (f32[]) tuple(call.6)
}
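`jax.xla_computation` has been deprecated in recent JAX releases. A sketch of the ahead-of-time route, assuming a JAX version that provides `jax.jit(...).lower(...)`: lowering produces an MLIR text form of the program (StableHLO on recent versions), and compiling it yields an executable plus the optimized HLO text that XLA generated for the current backend.

```python
import jax

def cube(x):
    return x ** 3

x = 10.24
lowered = jax.jit(cube).lower(x)   # trace and lower, without compiling yet
stablehlo_text = lowered.as_text()
print(stablehlo_text)

compiled = lowered.compile()       # XLA compilation for the current backend
print(compiled.as_text())          # optimized HLO produced by XLA
result = compiled(x)               # run the precompiled executable
print(result)
```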




**Exercise 39: Use the name `jaxton_cube_fn` internally for the `cube_jit` function and assign the named function to `cube_named_jit`**

In [44]:
cube_named_jit = jax.named_call(jit_cube, name='jaxton_cube_fn')
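`jax.named_call` and the related `jax.named_scope` context manager change only the names attached to operations in the lowered program (useful when reading profiler traces or HLO dumps); the computed values are unchanged. A small sketch, reusing the exercise's `jaxton_cube_fn` name for both:

```python
import jax

def cube(x):
    return x ** 3

# Wrap a whole function under a custom name
cube_named = jax.named_call(cube, name="jaxton_cube_fn")

@jax.jit
def cube_scoped(x):
    # Label every operation created inside this block
    with jax.named_scope("jaxton_cube_fn"):
        return x ** 3

named_value = cube_named(10.24)
scoped_value = cube_scoped(10.24)
print(named_value, scoped_value)
```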

**Exercise 40: Assign the XLA computation of `cube_named_jit` with input=10.24 to `cube_named_xla` and print its XLA HLO text**

In [48]:
cube_named_xla = jax.xla_computation(cube_named_jit)(10.24)
print(cube_named_xla.as_hlo_text())

HloModule xla_computation_cube, entry_computation_layout={(f32[])->(f32[])}

cube.2 {
  Arg_0.3 = f32[] parameter(0)
  multiply.4 = f32[] multiply(Arg_0.3, Arg_0.3)
  ROOT multiply.5 = f32[] multiply(Arg_0.3, multiply.4)
}

ENTRY main.8 {
  Arg_0.1 = f32[] parameter(0)
  call.6 = f32[] call(Arg_0.1), to_apply=cube.2
  ROOT tuple.7 = (f32[]) tuple(call.6)
}




<center>
    This completes Set 4: Just-In-Time (JIT) Compilation (Exercises 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>