<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 6: Automatic Differentiation (Solutions 51-60) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, JAX must be configured as shown in the code snippet below.
* Sample functions `cube` and `areas` will be used for the exercises.

In [1]:
!python3 -m pip install jax



In [2]:
import jax
import jax.numpy as jnp
import math
import os
import requests

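# Connect to the TPU driver when the runtime provides TPU_NAME;
# if it is missing (or the request fails), keep JAX's default backend
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    pass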

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
## sample cube function
def cube(x):
    return x**3

cube(2)

8

In [4]:
## sample areas function
def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

areas(2)

[1.7320508075688772, 4, 12.566370614359172]

**Exercise 51: JIT-compile the derivative of `cube` and assign it to `derivative_cube`**

In [5]:
derivative_cube = jax.jit(jax.grad(cube))
derivative_cube

<CompiledFunction at 0x7f416e57d8c0>
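As an optional sanity check, you can compare the compiled derivative against the hand-derived formula 3x² at several points. This sketch assumes a recent JAX release. Here `jax.grad(cube)` maps one scalar to one scalar, so `jax.vmap` is used to evaluate it over a batch of inputs at once.

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x**3

# d/dx x**3 = 3 * x**2
derivative_cube = jax.jit(jax.grad(cube))

# vmap evaluates the scalar-to-scalar derivative at every point of the batch
xs = jnp.array([-2.0, 0.0, 1.5, 7.0])
autodiff = jax.vmap(derivative_cube)(xs)
analytic = 3 * xs**2

print(jnp.allclose(autodiff, analytic))  # True
```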

**Exercise 52: Run `derivative_cube` with value=7**

In [6]:
derivative_cube(7.0)

DeviceArray(147., dtype=float32)
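The solution passes `7.0` rather than `7` because `jax.grad` only differentiates with respect to floating-point (or complex) inputs. The sketch below, assuming a recent JAX release, captures the `TypeError` raised for an integer argument instead of letting it propagate.

```python
import jax

def cube(x):
    return x**3

derivative_cube = jax.jit(jax.grad(cube))
print(derivative_cube(7.0))  # 147.0

# An integer input is rejected while the function is being traced
int_input_error = None
try:
    derivative_cube(7)
except TypeError as err:
    int_input_error = err

print(type(int_input_error).__name__)  # TypeError
```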

**Exercise 53: JIT-compile the value and derivative of `cube` together, assign it to `value_and_derivative_cube` and run it with value=7**

In [7]:
value_and_derivative_cube = jax.jit(jax.value_and_grad(cube))
value_and_derivative_cube(7.0)

(DeviceArray(343., dtype=float32), DeviceArray(147., dtype=float32))
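`jax.value_and_grad` returns the same pair you would get from calling `cube` and `jax.grad(cube)` separately. The difference is that it obtains both from a single trace of the function. The sketch below checks this equivalence.

```python
import jax

def cube(x):
    return x**3

value_and_derivative_cube = jax.jit(jax.value_and_grad(cube))
value, slope = value_and_derivative_cube(7.0)

# The same two numbers, computed with two separate calls
value_separate = cube(7.0)
slope_separate = jax.grad(cube)(7.0)

print(float(value), float(slope))  # 343.0 147.0
```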

**Exercise 54: JIT-compile the second order derivative of `cube`, assign it to `derivative_cube_2` and run it with value=7**

In [8]:
derivative_cube_2 = jax.jit(jax.grad(jax.grad(cube)))
derivative_cube_2(7.0)

DeviceArray(42., dtype=float32)
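Nesting `jax.grad` generalizes to any order. The sketch below uses a hypothetical helper, `nth_derivative`, which is not a JAX API. It simply applies `jax.grad` n times and evaluates derivatives 0 through 4 of `cube` at 7.

```python
import jax

def cube(x):
    return x**3

def nth_derivative(f, n):
    """Hypothetical helper (not part of JAX): apply jax.grad n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f

# Mathematically, d^n/dx^n x**3 at x=7 gives 343, 147, 42, 6, 0 for n = 0..4
derivatives_at_7 = [float(jax.jit(nth_derivative(cube, n))(7.0)) for n in range(5)]
print(derivatives_at_7)
```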

**Exercise 55: JIT-compile the hessian of `cube`, assign it to `hessian_cube` and run it with value=7**

In [9]:
hessian_cube = jax.jit(jax.hessian(cube))
hessian_cube(7.0)

DeviceArray(42., dtype=float32)
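For a scalar input, the Hessian collapses to the second derivative, which is why Exercises 54 and 55 agree. To see a genuine matrix, the sketch below uses `sum_of_cubes`, a hypothetical vector extension of `cube` that is not part of the exercises. Its Hessian is diagonal with entries 6·vᵢ.

```python
import jax
import jax.numpy as jnp

def sum_of_cubes(v):
    # Hypothetical vector version of `cube`: f(v) = sum_i v_i**3
    return jnp.sum(v**3)

hessian_sum_of_cubes = jax.jit(jax.hessian(sum_of_cubes))
H = hessian_sum_of_cubes(jnp.array([1.0, 2.0, 7.0]))

# Off-diagonal terms vanish because each v_i**3 depends on one coordinate only
print(jnp.allclose(H, jnp.diag(jnp.array([6.0, 12.0, 42.0]))))  # True
```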

**Exercise 56: JIT-compile `areas`, assign it to `jit_areas` and run it with value=9**

In [10]:
jit_areas = jax.jit(areas)
jit_areas(9)

[DeviceArray(35.074028, dtype=float32),
 DeviceArray(81, dtype=int32),
 DeviceArray(254.46901, dtype=float32)]
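The `int32` square area above comes from passing the integer `9`: `x**2` stays an integer, while the other two areas are multiplied by Python floats and become floating point. The sketch below compares the dtypes for an integer and a float input. It assumes JAX's default configuration, where 64-bit types are disabled.

```python
import math
import jax
import jax.numpy as jnp

def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

jit_areas = jax.jit(areas)

int_result = jit_areas(9)      # the square area x**2 keeps an integer dtype
float_result = jit_areas(9.0)  # all three areas are floating point

print([str(r.dtype) for r in int_result])
print([str(r.dtype) for r in float_result])
```

`jax.jit` traces and compiles a separate version for each distinct input dtype. The two calls above therefore do not share a compiled program.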

**Exercise 57: Compute the Jacobian of `areas` using forward-mode automatic differentiation, assign it to `jacfwd_areas` and run it with value=9**

In [11]:
jacfwd_areas = jax.jacfwd(areas)
jacfwd_areas(9.0)

[DeviceArray(7.7942286, dtype=float32),
 DeviceArray(18., dtype=float32),
 DeviceArray(56.548668, dtype=float32)]
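Forward mode builds the Jacobian one column at a time from Jacobian-vector products. `areas` has a single scalar input, so its Jacobian is one column, and a single JVP with tangent `1.0` recovers all of it. The sketch below checks this against `jax.jacfwd`.

```python
import math
import jax
import jax.numpy as jnp

def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

# One forward pass with unit tangent = the whole (single) Jacobian column
_, column = jax.jvp(areas, (9.0,), (1.0,))
jacobian = jax.jacfwd(areas)(9.0)

print(all(jnp.allclose(c, j) for c, j in zip(column, jacobian)))  # True
```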

**Exercise 58: Compute the Jacobian of `areas` using reverse-mode automatic differentiation, assign it to `jacrev_areas` and run it with value=9**

In [12]:
jacrev_areas = jax.jacrev(areas)
jacrev_areas(9.0)

[DeviceArray(7.7942286, dtype=float32),
 DeviceArray(18., dtype=float32),
 DeviceArray(56.548668, dtype=float32)]
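Reverse mode instead builds the Jacobian one row at a time. It pulls back a one-hot cotangent over the outputs through `jax.vjp`. `areas` has three outputs, so this takes three pullbacks. The explicit loop below is only for illustration; `jax.jacrev` batches these pullbacks internally.

```python
import math
import jax
import jax.numpy as jnp

def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

outputs, vjp_fn = jax.vjp(areas, 9.0)

rows = []
for i in range(len(outputs)):
    # One-hot cotangent with the same list structure as the outputs
    one_hot = [jnp.ones_like(o) if j == i else jnp.zeros_like(o) for j, o in enumerate(outputs)]
    (row,) = vjp_fn(one_hot)  # one cotangent per primal input; areas has one
    rows.append(row)

jacobian = jax.jacrev(areas)(9.0)
print(all(jnp.allclose(r, j) for r, j in zip(rows, jacobian)))  # True
```

Both modes give the same Jacobian, but here forward mode needs one pass and reverse mode needs three. In general, `jax.jacfwd` suits "tall" Jacobians (few inputs, many outputs) and `jax.jacrev` suits "wide" ones.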

**Exercise 59: Compute the Jacobian-vector product of `cube` at value=7 with vector=9 and assign it to `jvp_cube`**

In [13]:
jvp_cube = jax.jvp(cube, (7.0,), (9.0,))
jvp_cube

(DeviceArray(343., dtype=float32), DeviceArray(1323., dtype=float32))
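For a scalar function, the Jacobian-vector product is just the derivative scaled by the tangent: cube'(7) · 9 = 147 · 9 = 1323. The sketch below confirms this against `jax.grad`.

```python
import jax

def cube(x):
    return x**3

primal_out, tangent_out = jax.jvp(cube, (7.0,), (9.0,))

# JVP tangent = derivative at the primal point times the input tangent
expected = jax.grad(cube)(7.0) * 9.0

print(float(primal_out), float(tangent_out), float(expected))  # 343.0 1323.0 1323.0
```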

**Exercise 60: Compute the linear approximation of `areas` at value=5, assign it to `areas_linear` and apply it to the tangent value=9**

In [14]:
_, areas_linear = jax.linearize(areas, 5.0)
areas_linear(9.0)

[DeviceArray(38.97114, dtype=float32),
 DeviceArray(90., dtype=float32),
 DeviceArray(282.74335, dtype=float32)]
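`jax.linearize` returns the primal outputs together with a linear function that applies the Jacobian at the base point. For this reason, `areas_linear(9.0)` equals J(5)·9 rather than an estimate of `areas(9)`. The sketch below combines both outputs into the tangent-line approximation f(x₀) + J(x₀)(x − x₀). It uses `first_order_areas`, a hypothetical helper, and evaluates it near x₀ = 5.

```python
import math
import jax

def areas(x):
    return [math.sqrt(3)*x**2/4, x**2, math.pi*x**2]

x0 = 5.0
areas_at_x0, areas_linear = jax.linearize(areas, x0)

def first_order_areas(x):
    # Hypothetical helper: f(x0) + J(x0) * (x - x0), the tangent-line approximation
    deltas = areas_linear(x - x0)
    return [f0 + d for f0, d in zip(areas_at_x0, deltas)]

approx = first_order_areas(5.1)
exact = areas(5.1)
print([float(a) for a in approx])
print(exact)
```

Each area is a multiple of x², so the approximation error here is exactly that multiple times (x − x₀)² = 0.01. For the square, 26.0 is approximated against 26.01.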

<center>
    This completes Set 6: Automatic Differentiation (Solutions 51-60) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>