In [None]:
import subprocess
import os

try:
    # `nvidia-smi` can only be launched (and exit with status 0) when an NVIDIA
    # driver is installed, so a successful call is used as the GPU probe.
    subprocess.check_output('nvidia-smi')
    print("a GPU is connected.")
except (OSError, subprocess.CalledProcessError):
    # OSError: the binary is missing or not executable.
    # CalledProcessError: it ran but returned a non-zero exit status.
    # Either way there is no usable GPU -> fall back to TPU or CPU.
    if os.environ.get("COLAB_TPU_ADDR"):
        print("A TPU is connected.")
        import jax.tools.colab_tpu
        jax.tools.colab_tpu.setup_tpu()
    else:
        print("Only CPU accelerator is connected.")
        # x8 cpu devices - number of (emulated) host devices.
        # Set here, ahead of the `import jax` that follows this block.
        os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap

import matplotlib.pyplot as plt
import numpy as np

In [None]:
import copy
from typing import Dict


def plot_performance(data: Dict, title: str) -> None:
    """Plot run times as a bar chart and print how much faster the best run is.

    Args:
        data: Mapping from implementation name to its average run time (seconds).
        title: Title of the bar chart.
    """
    runs = list(data.keys())
    times = list(data.values())

    plt.bar(runs, times, width=0.4)
    plt.xlabel("Implementation")
    plt.ylabel("Average time taken (s)")
    plt.title(title)
    plt.show()

    if not data:
        # Nothing to compare; min() on an empty mapping would raise ValueError.
        return

    best_perf_key = min(data, key=data.get)
    best_time = data[best_perf_key]

    for k in runs:
        if k == best_perf_key:
            continue
        # Speed-up = slower time / fastest time, so the factor is always >= 1.
        # (Dividing the other way round reports e.g. "0.25x faster".)
        speedup = data[k] / best_time if best_time else float("inf")
        print(f"{best_perf_key} was {round(speedup, 2)}x faster than {k}")




In [None]:
# Report how many devices JAX sees, then list them.
num_devices = jax.device_count()
device_list = jax.devices()
print(f"Num devices: {num_devices}")
print(f" Devices: {device_list}")


JAX = Autograd + XLA (accelerated linear algebra)

* automatic differentiation (grad)
* parallelization (pmap)
* vectorization (vmap)
* just-in-time compilation (jit)

* XLA allows for accelerator agnostic computing

In [None]:
# NumPy version: sample sin(x) at 100 points on [-pi, pi] and plot the curve.
x = np.linspace(-np.pi, np.pi, 100)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y, "b", label="y = sin (x)")
ax.legend(loc="best")
plt.show()

In [None]:
# JAX version: the same calls as NumPy, made through jax.numpy.
x = jnp.linspace(-jnp.pi, jnp.pi, 100)

# Plot sin(x) in blue, then cos(x) in red, each in its own figure.
for curve_fn, line_style, curve_label in (
    (jnp.sin, "b", "y = sin (x)"),
    (jnp.cos, "r", "y = cos (x)"),
):
    y = curve_fn(x)
    plt.plot(x, y, line_style, label=curve_label)
    plt.legend(loc="best")
    plt.show()

What are the differences between JAX and NumPy?
* JAX arrays are immutable
* JAX handles randomness EXPLICITLY


In [None]:
# Example in NumPy: arrays are mutable, so item assignment
# modifies the array in place.
x = np.arange(10)
x[0] = 10
print(x)

In [None]:
# JAX issue: arrays are immutable, so in-place item assignment raises.
# Only the expected TypeError is caught, so any other failure stays visible.
try:
    x = jnp.arange(10)
    x[0] = 10
except TypeError as e:
    print("Exception {}".format(e))


In [None]:
# Solution: use the functional `.at[...].set(...)` update. It returns a new
# array carrying the change, while the original `x` keeps its values.
x = jnp.arange(10)
new_x = x.at[0].set(10)

print(f" new_x: {new_x} original x: {x}")

In [None]:
# Randomness: NumPy first (one global, implicit generator state), as a
# baseline for the JAX approach.

# Seed NumPy's global generator.
np.random.seed(42)

# Snapshot the current global generator state so it can be compared later.
prng_state = np.random.get_state()

def is_prng_state_the_same(prng_1, prng_2):
    """Return True if two NumPy PRNG state tuples are element-wise equal.

    Args:
        prng_1: First state tuple, e.g. from ``np.random.get_state()``.
        prng_2: Second state tuple.

    Returns:
        bool: True only when both tuples have the same length and every pair
        of corresponding entries is equal.
    """
    # zip() stops at the shorter tuple, so differing lengths must be rejected
    # up front instead of being silently treated as equal.
    if len(prng_1) != len(prng_2):
        return False
    # array_equal copes with both the scalar fields and the key array, and
    # returns False on a shape mismatch instead of broadcasting.
    return all(np.array_equal(a, b) for a, b in zip(prng_1, prng_2))

# Draw three samples; before each draw (after the first) re-snapshot the
# global state, then report whether the draw changed it.
for sample_idx in (1, 2, 3):
    if sample_idx > 1:
        prng_state = np.random.get_state()
    sample = np.random.normal()
    state_changed = not is_prng_state_the_same(prng_state, np.random.get_state())
    print(f"sample {sample_idx} = {sample} Did prng state change: {state_changed}")

### NumPy's global random state is updated every time a new random number is generated. We don't want this, because we want to handle randomness in a REPRODUCIBLE way across different threads/processes/devices.



In [None]:
from jax import random

# Sampling three times with the same key.
key = random.PRNGKey(42)
for sample_idx in range(1, 4):
    print(f"sample {sample_idx} = {random.normal(key)}")

In [None]:

from jax import random

key = random.PRNGKey(42)
sample_1 = random.normal(key)
print(f"sample 1 = {sample_1}")

# Split the key into a key to carry forward and a subkey to consume now.
# Which of the two is kept and which is used immediately is arbitrary.
new_key, subkey = random.split(key)
sample_2 = random.normal(subkey)
print(f"sample 2 = {sample_2}")

# Split the carried key again to get the next subkey.
new_key2, subkey = random.split(new_key)
sample_3 = random.normal(subkey)
print(f"sample 3 = {sample_3}")

In [None]:
# Calculating a dot product in NumPy on CPU

size = 1000
# Two (size x size) matrices sampled from NumPy's global generator.
x = np.random.normal(size = (size, size))
y = np.random.normal(size = (size, size))

# `-o` makes %timeit return its result object (stored in numpy_time);
# `-n 10` runs the statement 10 times per timing loop.
numpy_time = %timeit -o -n 10 a_np = np.dot(x,y)



In [None]:

size = 1000
key1, key2 = jax.random.split(jax.random.PRNGKey(42), num=2)
x = jax.random.normal(key1, shape=(size, size))
y = jax.random.normal(key2, shape=(size, size))
jax_time = %timeit -o -n 10 jnp.dot(y, x.T).block_until_ready()





JAX Transformations


In [None]:
# jit

def relu(x):
    """Rectified linear unit: element-wise maximum of 0 and x."""
    return jnp.maximum(0, x)


# Wrap relu with jit to get a compiled version of the same function.
relu_jit = jit(relu)

sample_input = jnp.array([-1, 0, 1])
print(relu_jit(sample_input))


In [None]:
# grad

def f(x):
    """Polynomial 6x^5 - 4x^3 + 2x^2 - 1."""
    return 6 * x**5 - 4 * x**3 + 2 * x**2 - 1


# grad(f) builds a new function that evaluates the derivative of f.
dfdx = grad(f)
print(dfdx(2.0))

In [None]:
# vmap

def min_max(x):
    """Return a length-2 array [min(x), max(x)] reduced over all elements of x."""
    smallest = jnp.min(x)
    largest = jnp.max(x)
    return jnp.array([smallest, largest])

batch_size = 3
# 15 consecutive integers arranged into `batch_size` rows.
batched_x = np.arange(15).reshape(batch_size, -1)
print(batched_x)

# Called on the whole batch, min_max reduces over every element at once.
whole_batch_min_max = min_max(batched_x)
print(whole_batch_min_max)

def vmap_min_max(x):
    """Apply min_max to each slice of x along its leading axis via vmap."""
    per_example_min_max = vmap(min_max)
    return per_example_min_max(x)

print(vmap_min_max(batched_x))



In [None]:
@jit
def manual_batch_min_max_loop(batched_x):
    """Batch min_max by hand: call it on each leading-axis slice and collect the results."""
    return jnp.array([min_max(example) for example in batched_x])

print(manual_batch_min_max_loop(batched_x))


@jit
def manual_batch_min_max_axis(batched_x):
    """Batched min/max computed with axis reductions.

    Args:
        batched_x: 2-D array of shape (batch, features).

    Returns:
        Array of shape (batch, 2) where row i is [min, max] of batched_x[i].
    """
    row_min = jnp.min(batched_x, axis=1)
    row_max = jnp.max(batched_x, axis=1)
    # Stack along axis 1 so each row reads [min, max]. The default axis=0
    # would produce shape (2, batch), the transpose of the per-example layout.
    return jnp.stack([row_min, row_max], axis=1)

print(manual_batch_min_max_axis(batched_x))

In [None]:
# vmap

@jit
def min_max_vmap(batched_x):
    """Jit-compiled min_max mapped over the leading axis of batched_x."""
    return vmap(min_max, in_axes=0)(batched_x)

## vmap maps over the leading axis, so a single vector needs one added first:
## its shape changes from (5,) to (1, 5), which makes the vmapping possible.
single_x = batched_x[0]  # one example from the batch
x_with_extra_dim = jax.numpy.expand_dims(single_x, axis = 0)
print(min_max_vmap(x_with_extra_dim))


In [None]:
# pmap

### TO DO

## Linear Regression

In [None]:


# simple toy dataset



x_data_list = [210, 160, 240, 140, 300]
y_data_list = [4, 3.3, 3.7, 2.3, 5.4]


def loss_function(b, w, x=None, y=None):
    """Half mean-squared error of the linear model ``f = w * x + b``.

    Args:
        b: Bias (intercept).
        w: Weight (slope).
        x: Inputs; defaults to the toy dataset ``x_data_list``.
        y: Targets; defaults to the toy dataset ``y_data_list``.

    Returns:
        Scalar ``0.5 * mean((y - f) ** 2)``.
    """
    # Bind the data explicitly instead of reading whatever the notebook-global
    # `x` / `y` currently hold: earlier cells rebind those names to unrelated
    # arrays, which would make the loss silently use the wrong data.
    x = jnp.asarray(x_data_list if x is None else x)
    y = jnp.asarray(y_data_list if y is None else y)

    f = w * x + b
    errors = jnp.square(y - f)

    return 1 / 2 * jnp.mean(errors)


# Differentiate w.r.t. positional arguments 0 and 1, i.e. (b, w).
auto_grad = jax.grad(loss_function, argnums = (0,1))




def manual_grad(b, w):
    """Hand-written gradients of the loss w.r.t. b and w over the toy dataset.

    Returns:
        Tuple (grad_b, grad_w): the residual ``f - y`` and the residual times
        ``x``, each averaged over all samples.
    """
    n_samples = len(x_data_list)
    sum_b = 0
    sum_w = 0
    for x_i, y_i in zip(x_data_list, y_data_list):
        residual = (w * x_i + b) - y_i
        sum_b += residual
        sum_w += residual * x_i
    return sum_b / n_samples, sum_w / n_samples

b, w = 2.5, 3.5

# Gradients from JAX automatic differentiation.
autograd_grads = auto_grad(b, w)
grad_b_autograd, grad_w_autograd = autograd_grads
print("Autograd         grad_b:", grad_b_autograd, "  grad_w", grad_w_autograd)

# Gradients from the hand-written formulas, for comparison.
manual_grads = manual_grad(b, w)
grad_b_manual, grad_w_manual = manual_grads
print("Manual gradients grad_b:", grad_b_manual, "  grad_w", grad_w_manual)
