<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 4: Just-In-Time (JIT) Compilation (Solutions 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, configure JAX as shown in the code snippet below. If no TPU is available, the setup is skipped and JAX falls back to its default backend.
* A sample function, `cube`, is used throughout the exercises.

In [1]:
!python3 -m pip install jax



In [2]:
import jax
import jax.numpy as jnp
import os
import requests

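try:
    # Request the nightly TPU driver runtime, then point JAX at the remote TPU
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # No TPU_NAME set (or the driver is unreachable): keep JAX's default backend
    pass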

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
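Whichever branch of the setup cell runs, it can help to confirm which backend JAX actually picked before timing anything. A minimal sketch using `jax.default_backend()` and `jax.device_count()`; the printed values depend on the runtime (the eight `TpuDevice` entries above correspond to eight TPU cores).

```python
import jax

# Platform of the default backend, e.g. 'cpu', 'gpu' or 'tpu'
backend = jax.default_backend()

# Number of devices visible to JAX, followed by a description of each one
n_devices = jax.device_count()
print(backend, n_devices)
for d in jax.devices():
    print(d)
```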

In [3]:
def cube(x):
    return x**3

cube(2.1)

9.261000000000001

**Exercise 31: JIT-compile the `cube` function and assign it to `cube_jit`**

In [4]:
cube_jit = jax.jit(cube)
cube_jit

<CompiledFunction at 0x7fd5248a85f0>
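`jax.jit` can also be applied as a decorator, and the compiled function accepts arrays as well as scalars. Arguments that should be treated as compile-time constants can be marked with `static_argnums`; each distinct static value then triggers its own compilation. A minimal sketch; `cube_decorated` and `power` are hypothetical names used only for illustration.

```python
import functools

import jax
import jax.numpy as jnp

@jax.jit
def cube_decorated(x):
    # Equivalent to jax.jit(cube) from the exercise, written as a decorator
    return x ** 3

# Hypothetical helper: n is static, so each distinct n is compiled separately
@functools.partial(jax.jit, static_argnums=1)
def power(x, n):
    return x ** n

print(cube_decorated(jnp.arange(4.0)))  # elementwise cube of [0., 1., 2., 3.]
print(power(jnp.arange(4.0), 3))
```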

**Exercise 32: Display the execution time of `cube_jit` for the first run (including compilation overhead) with input=10.24**

In [5]:
%%time
cube_jit(10.24)

CPU times: user 221 ms, sys: 149 ms, total: 370 ms
Wall time: 366 ms


DeviceArray(1073.7418, dtype=float32)

**Exercise 33: Display the execution time of `cube_jit` for the second run (without compilation overhead) with input=10.24**

In [6]:
%%time
cube_jit(10.24)

CPU times: user 16.8 ms, sys: 20.5 ms, total: 37.2 ms
Wall time: 37.4 ms


DeviceArray(1073.7418, dtype=float32)
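JAX dispatches work asynchronously, so timing a bare call may stop the clock before the device has finished computing. A minimal sketch of timing outside IPython with `time.perf_counter` and `block_until_ready()`, which waits for the result; `time_call` is a hypothetical helper, and absolute timings will vary by backend.

```python
import time

import jax

def cube(x):
    return x ** 3

cube_jit = jax.jit(cube)

def time_call(fn, *args):
    """Hypothetical helper: run fn(*args), wait for the device, return (result, seconds)."""
    start = time.perf_counter()
    out = fn(*args).block_until_ready()  # block until the asynchronous computation completes
    return out, time.perf_counter() - start

first_out, first_s = time_call(cube_jit, 10.24)    # includes tracing and XLA compilation
second_out, second_s = time_call(cube_jit, 10.24)  # reuses the cached executable
print(f"first: {first_s * 1e3:.2f} ms, second: {second_s * 1e3:.2f} ms")
```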

**Exercise 34: Run `cube_jit` with input=10.24 and assign it to `cube_value`**

In [7]:
cube_value = cube_jit(10.24)
cube_value

DeviceArray(1073.7418, dtype=float32)
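`cube_value` is a JAX array that lives on the device, and it is `float32` because JAX disables 64-bit types unless `jax_enable_x64` is turned on; that is why it differs in the trailing digits from the plain-Python result in the next exercise. A minimal sketch of bringing the value back to the host with `jax.device_get` or `float()`, assuming the default configuration:

```python
import jax
import numpy as np

cube_jit = jax.jit(lambda x: x ** 3)
cube_value = cube_jit(10.24)

host_value = jax.device_get(cube_value)  # copy from device memory to a NumPy value on the host
as_python_float = float(cube_value)      # Python float built from the float32 result

print(host_value.dtype)  # float32 while x64 is disabled (the default)
print(as_python_float)   # close to, but not exactly, 10.24 ** 3
```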

**Exercise 35: Run `cube_jit` with JIT disabled and input=10.24 and assign it to `cube_value_nojit`**

In [8]:
with jax.disable_jit():
    cube_value_nojit = cube_jit(10.24)

cube_value_nojit

1073.7418240000002
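`jax.disable_jit()` makes jitted functions run their Python body eagerly, which is mainly useful for debugging. One way to see the difference is a Python side effect: under `jax.jit` it runs only while the function is being traced, whereas with JIT disabled it runs on every call. A minimal sketch; `trace_count` and `cube_traced` are hypothetical names.

```python
import jax

trace_count = 0  # hypothetical counter, incremented whenever the Python body executes

def cube_traced(x):
    global trace_count
    trace_count += 1  # Python side effect: under jit this only happens during tracing
    return x ** 3

cube_traced_jit = jax.jit(cube_traced)

cube_traced_jit(10.24)  # first call: traces the Python body, then compiles
cube_traced_jit(10.24)  # cache hit: runs the compiled code, Python body is skipped
print(trace_count)

with jax.disable_jit():
    cube_traced_jit(10.24)  # Python body runs eagerly
    cube_traced_jit(10.24)  # ...and again on every call
print(trace_count)
```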

**Exercise 36: Evaluate the output shape of `cube_jit` with input=10.24 and assign it to `cube_shape`**

In [9]:
cube_shape = jax.eval_shape(cube_jit, 10.24)
cube_shape

ShapeDtypeStruct(shape=(), dtype=float32)
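`jax.eval_shape` traces the function with abstract values and performs no floating-point work, so it also accepts a `jax.ShapeDtypeStruct` in place of real data. A minimal sketch for a hypothetical (3, 4) float32 input:

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x ** 3

cube_jit = jax.jit(cube)

# Describe the input by shape and dtype only; no array is allocated
input_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
output_spec = jax.eval_shape(cube_jit, input_spec)

# An elementwise cube keeps the input's shape and dtype
print(output_spec.shape, output_spec.dtype)
```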

**Exercise 37: Create the jaxpr of `cube_jit` with input=10.24 and assign it to `cube_jaxpr`**

In [10]:
cube_jaxpr = jax.make_jaxpr(cube_jit)(10.24)
cube_jaxpr

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = integer_pow[ y=3 ] a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=cube ] a
  in (b,) }
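The printed jaxpr is also a Python object that can be inspected. A minimal sketch that traces the un-jitted `cube`, so its equations appear directly rather than nested inside a call primitive (that wrapper is `xla_call` in the output above and may be named differently in newer JAX versions):

```python
import jax

def cube(x):
    return x ** 3

closed_jaxpr = jax.make_jaxpr(cube)(10.24)  # ClosedJaxpr: a jaxpr plus its constants
jaxpr = closed_jaxpr.jaxpr

print("inputs:", jaxpr.invars)
print("outputs:", jaxpr.outvars)
for eqn in jaxpr.eqns:
    # Each equation records the primitive that was applied and its parameters
    print(eqn.primitive.name, eqn.params)
```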

**Exercise 38: Assign the XLA computation of `cube_jit` with input=10.24 to `cube_xla` and print its XLA HLO text**

In [11]:
cube_xla = jax.xla_computation(cube_jit)(10.24)
print(cube_xla.as_hlo_text())

HloModule xla_computation_cube.12

jit_cube__1.3 {
  constant.5 = pred[] constant(false)
  parameter.4 = f32[] parameter(0)
  multiply.6 = f32[] multiply(parameter.4, parameter.4)
  multiply.7 = f32[] multiply(parameter.4, multiply.6)
  ROOT tuple.8 = (f32[]) tuple(multiply.7)
}

ENTRY xla_computation_cube.12 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.9 = (f32[]) call(parameter.1), to_apply=jit_cube__1.3
  get-tuple-element.10 = f32[] get-tuple-element(call.9), index=0
  ROOT tuple.11 = (f32[]) tuple(get-tuple-element.10)
}
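`jax.xla_computation` comes from older JAX releases and has since been deprecated. In newer versions, a comparable view is available through the ahead-of-time API, `jax.jit(...).lower(...)`, whose text output is an MLIR module (StableHLO in recent releases) rather than the HLO shown above. A minimal sketch, assuming a recent JAX version:

```python
import jax

def cube(x):
    return x ** 3

lowered = jax.jit(cube).lower(10.24)  # trace and lower for a float32 scalar argument
print(lowered.as_text())              # MLIR text of the lowered module

compiled = lowered.compile()          # ahead-of-time compiled XLA executable
print(compiled(10.24))
```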




**Exercise 39: Use the name `jaxton_cube_fn` internally for the `cube_jit` function and assign the named function to `cube_named_jit`**

In [12]:
cube_named_jit = jax.named_call(cube_jit, name='jaxton_cube_fn')
cube_named_jit

<function __main__.cube(x)>
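`jax.named_call` is most useful when the named function is one piece of a larger computation, since the name can then label that sub-block in lowered or compiled output, which may make HLO dumps and profiles easier to read. A minimal sketch; `model` and its `+ 1.0` step are hypothetical, and exactly how the name surfaces depends on the JAX version.

```python
import jax

def cube(x):
    return x ** 3

# Give the cube step its own name inside a larger jitted computation
named_cube = jax.named_call(cube, name='jaxton_cube_fn')

@jax.jit
def model(x):
    # Hypothetical surrounding computation: the named block plus one extra op
    return named_cube(x) + 1.0

print(model(10.24))
```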

**Exercise 40: Assign the XLA computation of `cube_named_jit` with input=10.24 to `cube_named_xla` and print its XLA HLO text**

In [13]:
cube_named_xla = jax.xla_computation(cube_named_jit)(10.24)
print(cube_named_xla.as_hlo_text())

HloModule xla_computation_cube__1.18

jit_cube__2.3 {
  constant.5 = pred[] constant(false)
  parameter.4 = f32[] parameter(0)
  multiply.6 = f32[] multiply(parameter.4, parameter.4)
  multiply.7 = f32[] multiply(parameter.4, multiply.6)
  ROOT tuple.8 = (f32[]) tuple(multiply.7)
}

jaxton_cube_fn.9 {
  constant.11 = pred[] constant(false)
  parameter.10 = f32[] parameter(0)
  call.12 = (f32[]) call(parameter.10), to_apply=jit_cube__2.3
  get-tuple-element.13 = f32[] get-tuple-element(call.12), index=0
  ROOT tuple.14 = (f32[]) tuple(get-tuple-element.13)
}

ENTRY xla_computation_cube__1.18 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.15 = (f32[]) call(parameter.1), to_apply=jaxton_cube_fn.9
  get-tuple-element.16 = f32[] get-tuple-element(call.15), index=0
  ROOT tuple.17 = (f32[]) tuple(get-tuple-element.16)
}




<center>
    This completes Set 4: Just-In-Time (JIT) Compilation (Solutions 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>