# JAX Notes

## Contents

**1. You don't know JAX**<br>
 - Basic functionality of JAX
 - Basic Neural Net application
 
**2. Quickstart**<br>
 - Basic functionality, with a closer look under the hood
 
**3. Neural Networks and Data-loading**<br>

**4. JAX - The Sharp Bits**

# 1. You don't know JAX

https://colinraffel.com/blog/you-don-t-know-jax.html

JAX is NumPy extended to run on specialized parallel-computing hardware (GPUs and TPUs) for ML workloads.

So what does it need that NumPy doesn't have?<br>
    1. JIT - Just-in-time compilation --> Compiles functions into optimized code so that repeated calls run faster.<br>
    2. Gradients - Automatic differentiation: the gradient of a function's output with respect to chosen parameters.<br>
    3. Vectorization - Batching of inputs so that a function is applied to the whole batch in parallel.<br>
    
JAX calls these additions to NumPy 'transformations'. They are implemented as functions that take a function and return a new one, namely `jit()`, `grad()` and `vmap()`/`pmap()`.
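
A minimal sketch of the three transformations applied to a toy function. The function `f` and the names `df`, `fast_f`, and `batched_df` are made up for illustration; only `jax.jit`, `jax.grad`, and `jax.vmap` come from JAX.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function (illustrative): f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x

df = jax.grad(f)            # new function computing the derivative 2x + 3
fast_f = jax.jit(f)         # compiled version of f
batched_df = jax.vmap(df)   # derivative applied to every element of a batch

xs = jnp.array([0.0, 1.0, 2.0])
print(df(1.0))          # 5.0
print(fast_f(2.0))      # 10.0
print(batched_df(xs))   # derivative at 0, 1, 2: values 3, 5, 7
```

Each transformation returns an ordinary Python function, so they compose freely, for example `jax.jit(jax.vmap(jax.grad(f)))`.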

In [1]:
import random # to generate random numbers
import itertools # to count

import jax
import jax.numpy as np # lets existing NumPy code run on JAX unchanged
import numpy as onp
# NOTE: The import convention for JAX is still divisive -->
# New, JAX-centric code imports jax.numpy as jnp and numpy as np.
# Old NumPy code that is now expected to run on JAX just rebinds np to jax.numpy;
# onp then stands for the original NumPy.

import time
from __future__ import print_function

Chosen example: learning the Exclusive OR (XOR) function with a neural network, a classic toy problem. It is famous from the book *Perceptrons* by Minsky and Papert, who showed that a single-layer perceptron cannot represent XOR; a network with a hidden layer, like the one below, can.



### Net & Loss

In [4]:
# Sigmoid nonlinearity - usual code
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Computes our network's output for a single input x of shape (2,)
def net(params, x):
    w1, b1, w2, b2 = params  # shapes: w1 (3, 2), b1 (3,), w2 (3,), b2 scalar
    hidden = np.tanh(np.dot(w1, x) + b1)  # output of hidden layer
    return sigmoid(np.dot(w2, hidden) + b2)  # output layer

# Cross-entropy loss
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

### Training - no batches, one sample at a time, without JIT

In [5]:
def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),  # b1
        onp.random.randn(3),  # w2
        onp.random.randn(),  #b2
    ]

loss_grad = jax.grad(loss) #loss is a function, loss_grad is also a function

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

startfull = time.time()
for n in itertools.count(): #just counts, consider using while?
    start = time.time()
    # Grab a single random input, notice x is jax array
    x = inputs[onp.random.choice(inputs.shape[0])] #randomly gives integer in range.  
    #True stochastic gradient descent. Since only one sample per cycle
    # Compute the target output
    y = onp.bitwise_xor(*x)
    
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y) #autogradient for your function, but again, only one input at a time
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 10 iterations, check whether we've solved XOR
    if not n % 10:
        end = time.time()
        print('Iteration {}'.format(n),'time taken:', end-start,'s')
        if test_all_inputs(inputs, params):
            break
endfull = time.time()
print('_'*100)
print('Total time taken:', endfull-startfull,'s')

Iteration 0 time taken: 0.009472131729125977 s
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 10 time taken: 0.005346059799194336 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 20 time taken: 0.005450010299682617 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 30 time taken: 0.005522966384887695 s
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 40 time taken: 0.0057489871978759766 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 50 time taken: 0.0054090023040771484 s
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
____________________________________________________________________________________________________
Total time taken: 0.29266786575317383 s


### Training - no batches, one sample at a time, with JIT

In [6]:

loss_grad = jax.jit(jax.grad(loss))#loss is a function, loss_grad is also a function

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

startfull = time.time()
for n in itertools.count(): #just counts, consider using while?
    start = time.time()
    # Grab a single random input, notice x is jax array
    x = inputs[onp.random.choice(inputs.shape[0])] #randomly gives integer in range.  
    #True stochastic gradient descent. Since only one sample per cycle
    # Compute the target output
    y = onp.bitwise_xor(*x)
    
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y) #autogradient for your function, but again, only one input at a time
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 10 iterations, check whether we've solved XOR
    if not n % 10:
        end = time.time()
        print('Iteration {}'.format(n),'time taken:', end-start,'s')
        if test_all_inputs(inputs, params):
            break
endfull = time.time()
print('_'*100)
print('Total time taken:', endfull-startfull,'s')

Iteration 0 time taken: 0.034810781478881836 s
[0 0] -> 1
[0 1] -> 1
[1 0] -> 0
[1 1] -> 1
Iteration 10 time taken: 0.00014781951904296875 s
[0 0] -> 0
[0 1] -> 1
[1 0] -> 0
[1 1] -> 1
Iteration 20 time taken: 0.0001499652862548828 s
[0 0] -> 1
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 30 time taken: 0.0001761913299560547 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 40 time taken: 0.0001418590545654297 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 50 time taken: 0.0001399517059326172 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 60 time taken: 0.00013780593872070312 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 70 time taken: 0.00013899803161621094 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 80 time taken: 0.00013589859008789062 s
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 90 time taken: 0.00013113021850585938 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 100 time taken: 0.0001347064971923828 s
[0 0] -> 0
[0 1] 

### Training - batching and JIT - mini-batch gradient descent

In [7]:
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))
# jit(vmap(grad(loss))): in_axes=(None, 0, 0) shares params across the batch and maps over axis 0 of x and y;
# out_axes=0 stacks the per-example gradients along axis 0
params = initial_params()

batch_size = 100

startfull = time.time()
for n in itertools.count():
    start = time.time()
    
    # Generate a batch of inputs
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)] #(batch_size (parallelization along here), feature_length) = (100,2)
    y = onp.bitwise_xor(x[:, 0], x[:, 1]) #only one axis
    
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)] #only difference is mean gradient for mini-batch
    if not n % 10:
        end = time.time()
        print('Iteration {}'.format(n),'time taken:', end-start,'s')
        if test_all_inputs(inputs, params):
            break
endfull = time.time()
print('_'*100)
print('Total time taken:', endfull-startfull,'s')

Iteration 0 time taken: 0.21003389358520508 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 10 time taken: 0.0005619525909423828 s
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 20 time taken: 0.0004937648773193359 s
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 30 time taken: 0.0005571842193603516 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 40 time taken: 0.0004856586456298828 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 50 time taken: 0.0005261898040771484 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 60 time taken: 0.0004711151123046875 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 70 time taken: 0.00048804283142089844 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 80 time taken: 0.00046825408935546875 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 90 time taken: 0.0004642009735107422 s
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 100 time taken: 0.0004730224609375 s
[0 0] -> 0
[0 1] -> 0
[1

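Batching gives more stable learning, with less chance of divergence or of settling in a poor minimum. As seen, vectorization increased the per-iteration time only modestly: roughly 0.5 ms for a batch of 100, against roughly 0.14 ms for a single sample with JIT.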
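
The averaging step in the batched loop can be checked on a smaller problem. The sketch below uses a made-up per-example loss (`example_loss`, a squared error of a linear model, not the XOR network) and compares the mean of the per-example gradients from `vmap(grad(...))` with the gradient of the mean loss taken directly.

```python
import jax
import jax.numpy as jnp

def example_loss(w, x, y):
    # Squared error of a linear model on one example (illustrative only)
    return (jnp.dot(w, x) - y) ** 2

def mean_loss(w, xs, ys):
    # Mean of the per-example losses over the batch
    per_example_losses = jax.vmap(example_loss, in_axes=(None, 0, 0))(w, xs, ys)
    return jnp.mean(per_example_losses)

w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([1.0, 0.0, -1.0])

# Per-example gradients, stacked along axis 0: shape (3, 2)
per_example = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(w, xs, ys)
averaged = jnp.mean(per_example, axis=0)

# Gradient of the mean loss, computed in one go
direct = jax.grad(mean_loss)(w, xs, ys)

print(per_example.shape)                          # (3, 2)
print(bool(jnp.allclose(averaged, direct, atol=1e-4)))
```

Both routes give the same update direction; the `vmap(grad(...))` form additionally exposes the individual per-example gradients.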

# 2. Quickstart

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

In [8]:
import jax.numpy as jnp # the other convention, for code written for JAX from the start
from jax import grad, jit, vmap
from jax import random

In [9]:
# Creating random numbers:
# one of the few places where JAX and NumPy differ in syntax

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [10]:
size = 30
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

# block_until_ready() is added because JAX uses asynchronous dispatch by default (see "Asynchronous Dispatch" below).

27.3 µs ± 363 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


**Asynchronous Dispatch**

When an operation such as `jnp.dot(x, x)` is executed, JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a `jax.Array` value (called `DeviceArray` in older releases), which is a future, i.e., a value that will be produced in the future on an accelerator device but isn’t necessarily available immediately. We can inspect the `shape` or `type` of the array without waiting for the computation that produced it to complete, and we can even pass it to another JAX computation. Only if we actually inspect the value of the array from the host, for example by printing it or by converting it into a plain old `numpy.ndarray`, will JAX force the Python code to wait for the computation to complete.

This is called asynchronous dispatch. 

Why is it useful?<br>
It allows the program to run ahead of an accelerator device, "in the case where the host program does not actually need to inspect the output of the specific accelerator computation". This allows the program to enqueue arbitrary amounts of work on the accelerator and make more efficient use of its time.

**NOTE: The below exercise will only make sense with a GPU**

In [11]:
x = random.uniform(random.PRNGKey(0), (1000, 1000))
%time jnp.dot(x, x)  

CPU times: user 41.1 ms, sys: 1.73 ms, total: 42.8 ms
Wall time: 14.6 ms


Array([[255.01973, 246.64864, 254.13373, ..., 233.67952, 247.68939,
        238.36853],
       [262.65982, 253.28915, 259.18246, ..., 239.03183, 253.16756,
        249.44124],
       [259.38916, 252.72754, 258.2306 , ..., 237.83559, 252.41093,
        246.62468],
       ...,
       [256.1581 , 250.092  , 254.72174, ..., 239.23874, 247.72684,
        244.16638],
       [268.22662, 258.91205, 262.33398, ..., 245.26648, 259.0539 ,
        258.337  ],
       [254.16138, 251.7543 , 256.083  , ..., 238.59848, 245.62595,
        240.22354]], dtype=float32)

The wall time can be surprisingly low because `%time` measures only how long it takes to dispatch the work to the accelerator, not the computation itself. To measure the actual time, force the result to be materialized (for example, by converting it to a NumPy array) or call `block_until_ready()`.

In [12]:
%time jnp.dot(x, x).block_until_ready() 

CPU times: user 39.8 ms, sys: 2.66 ms, total: 42.5 ms
Wall time: 7.58 ms


Array([[255.01973, 246.64864, 254.13373, ..., 233.67952, 247.68939,
        238.36853],
       [262.65982, 253.28915, 259.18246, ..., 239.03183, 253.16756,
        249.44124],
       [259.38916, 252.72754, 258.2306 , ..., 237.83559, 252.41093,
        246.62468],
       ...,
       [256.1581 , 250.092  , 254.72174, ..., 239.23874, 247.72684,
        244.16638],
       [268.22662, 258.91205, 262.33398, ..., 245.26648, 259.0539 ,
        258.337  ],
       [254.16138, 251.7543 , 256.083  , ..., 238.59848, 245.62595,
        240.22354]], dtype=float32)

In [13]:
from jax import device_put
import numpy as np  # NOTE: rebinds np (bound to jax.numpy in section 1) to plain NumPy
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

27.7 µs ± 265 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


The output of `jax.device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of `jax.device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster.

**1. Using `jax.jit` to speed up functions**

JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://www.tensorflow.org/xla). Let's try that.

In [14]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

2.11 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

565 µs ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


The first call to `selu_jit` traces and JIT-compiles `selu`; the compiled version is then cached and reused for later calls with inputs of the same shape and dtype.
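
The caching behavior can be observed with a small sketch. It relies on the fact that ordinary Python code inside a jitted function runs only while the function is being traced. The names `trace_log` and `scaled_sum` are made up for this illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical log, used only to observe when tracing happens

@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)   # Python side effect: runs only during tracing
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))    # first call: traced and compiled
scaled_sum(jnp.zeros(3))   # same shape and dtype: cached version is reused
scaled_sum(jnp.ones(5))    # new shape: traced and compiled again

print(trace_log)  # [(3,), (5,)]
```

Three calls produce only two traces, one per distinct input shape.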

**2. Using `jax.grad` for taking derivatives**

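Unlike PyTorch, where gradients are accumulated on tensors by calling `.backward()`, JAX takes a functional approach: `grad(f)` returns a new function that computes the gradient of `f`.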

In [16]:
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [17]:
# to verify
def derivative_sum_log(x):
    return jnp.exp(-x)/ (1.0 + jnp.exp(-x))**2
print(derivative_sum_log(x_small))

[0.25       0.19661196 0.10499357]


In [18]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256


For more advanced autodiff, you can use `jax.vjp` for reverse-mode vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:

In [19]:
from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
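
A short usage sketch for the `hessian` helper, repeated here so the block runs on its own. The test function `cubic` is made up for illustration; for f(x) = sum(x_i^3) the Hessian is the diagonal matrix with entries 6 * x_i, which gives something to check against.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    # Forward-over-reverse composition, as in the cell above
    return jit(jacfwd(jacrev(fun)))

def cubic(x):
    # Illustrative function: f(x) = sum(x_i^3), so H = diag(6 * x_i)
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(cubic)(x)

print(H.shape)                                      # (3, 3)
print(bool(jnp.allclose(H, jnp.diag(6.0 * x))))     # True
```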

**3. Auto-vectorization with `jax.vmap`**

`jax.vmap` is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with `jax.jit`, it can be just as fast as adding the batch dimensions by hand.

This means that instead of just applying `func(var) for var in vars` in a Python loop, `vmap` transforms `func` so that each of its primitive operations works on the whole batch at once.

In [20]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

How would you apply the `apply_matrix` function to each of the 10 vectors in batched_x?

In [21]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
466 µs ± 6.94 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
18.4 µs ± 160 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


# 3. Neural Networks and Data-loading

https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb

In [23]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Note that JAX can be used with any NumPy-compatible library for building networks; in this notebook the network is written by hand to keep things simple.

In [24]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer


def random_layer_params(m: 'inputdim', 
                        n: 'outputdim', 
                        key, scale=1e-2): # scale is the standard deviation applied to the standard-normal samples
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

The purpose of `random.split` is to allow you to generate an arbitrary number of independent pseudorandom values given a single key. You can certainly directly create as many random keys as you wish, but requiring this in every case would be problematic.

As an example, consider an iterative solver that creates a new random vector each iteration. If keys had to be constructed manually, the user would have to pass-in an arbitrarily large number of random keys to the solver. With `random.split`, passing a single key is sufficient, and the function can split it as needed to generate as many independent pseudorandom values as needed.
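
A minimal sketch of this pattern, assuming a made-up loop that stands in for the iterative solver. It shows that reusing a key reproduces the same numbers, while splitting once per iteration yields fresh values each time.

```python
from jax import random

key = random.PRNGKey(42)

# Reusing a key reproduces exactly the same numbers
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(bool((a == b).all()))   # True

# Inside a loop, split once per iteration: keep `key` for the next
# round and consume `subkey` for this round's random draw
draws = []
for step in range(3):
    key, subkey = random.split(key)
    draws.append(random.normal(subkey, (3,)))

print(bool((draws[0] == draws[1]).all()))   # False: different subkeys
```

The caller only ever supplies one key; the loop derives as many subkeys as it needs.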

**Auto-batching predictions**

In [25]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

# A function written to handle
# only one image at a time
def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [26]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [27]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28)) # 10 images at once
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')

Invalid shapes!


In [28]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0)) # params are shared (None); images are mapped over axis 0, the batch axis


# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [29]:
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [30]:
# import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
# tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

ModuleNotFoundError: No module named 'tensorflow_datasets'

In [32]:
data_dir = '/Users/abhinavrao/Abhinav/lie/tfdata/tfds'
import tensorflow_datasets as tfds

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

ModuleNotFoundError: No module named 'tensorflow_datasets'

In [1]:
import time
num_epochs = 10


def get_train_batches():
    # as_supervised=True gives us the (image, label) as a tuple instead of a dict
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    # You can build up an arbitrary tf.data input pipeline
    ds = ds.batch(batch_size).prefetch(1)
    # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(ds)

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in get_train_batches():
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

NameError: name 'data_dir' is not defined

# 4. JAX - The Sharp Bits

https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb#scrollTo=cDpQ5u63Ba_H

Some under-the-hood details about JAX that are good to know, because you cannot rely on error messages alone to catch them. JAX is designed to behave just like NumPy apart from the extra transformations, but its functional approach makes it deviate in places, and those deviations can cause silent problems if they are not addressed.

In [26]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

**1. JAX's functional approach**

JAX's `jit` compiles functions for speed, but it expects them to be functionally pure. A pure function does not read or write global state, has no side effects, and gives the same output every time for the same inputs.

In [27]:
def impure_print_side_effect(x):
    print("Executing function")  # This is a side-effect
    return x

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [28]:
g = 0.
def impure_uses_globals(x):
    return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
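
One way to avoid the stale global seen above is to pass the value in as an explicit argument, which keeps the function pure. This is a small sketch; `pure_add` and `offset` are illustrative names, not part of the original notebook.

```python
from jax import jit

@jit
def pure_add(x, offset):
    # `offset` is an explicit argument, so no hidden global state is read
    return x + offset

print(pure_add(4.0, 0.0))    # 4.0
print(pure_add(5.0, 10.0))   # 15.0
```

Because `offset` is a traced argument, changing its value between calls is picked up without any retracing surprises.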


In [29]:
g = 0.
def impure_saves_global(x):
    global g
    g = x
    return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value

First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


**NOTE**: A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state. So it's fine to use a stateful object such as a local `dict` inside the function:

In [30]:
def pure_uses_internal_state(x):
    state = dict(even=0, odd=0)
    for i in range(10):
        state['even' if i % 2 == 0 else 'odd'] += x
    return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))

50.0


It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a Python object that carries state in order to retrieve the next element, so it is incompatible with JAX's functional programming model. The code below shows some incorrect attempts to use iterators with JAX. Most of them raise an error, but some give unexpected results.

In [31]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

45
0


**2. JAX's immutability**

In numpy we can:

In [32]:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


In [33]:
%xmode Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0

Exception reporting mode: Minimal


TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

JAX disallows in-place mutation because it would conflict with its functional approach. Instead, array updates in JAX go through the `.at` property:

In [34]:
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)

updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


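**NOTE:** JAX array updates are out-of-place, which preserves functional purity: the update returns a new array and leaves the original untouched. Inside `jit`-compiled code, however, the compiler can perform the update in place when the original array is not reused, which saves memory.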
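
A small sketch of a functional update inside a jitted function. The helper `set_row` is a made-up name. The observable behavior is the same with or without `jit`: a new array comes back and the input is left unchanged. Whether the compiler reuses the buffer internally is an optimization that is not visible from Python.

```python
import jax
import jax.numpy as jnp

@jax.jit
def set_row(arr, row, value):
    # Functional update: returns a new array, `arr` itself is not modified
    return arr.at[row, :].set(value)

zeros = jnp.zeros((3, 3))
updated = set_row(zeros, 1, 1.0)

print(zeros.sum())     # 0.0: the input is unchanged
print(updated.sum())   # 3.0: one row of three ones
```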

In [23]:
print("original array unchanged:\n", jax_array)

original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


Index-based operations: besides `set`, the `.at[...]` property supports other indexed updates, such as `add`:

In [25]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]
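
As a further hedged illustration, `.at[...]` offers update methods beyond `set` and `add`, for example `multiply` and `min`. The array and slices below are made up for the example; each call returns a new array and leaves `arr` unchanged.

```python
import jax.numpy as jnp

arr = jnp.arange(6.0)                  # values 0, 1, 2, 3, 4, 5

doubled = arr.at[:3].multiply(2.0)     # first three entries doubled: 0, 2, 4, 3, 4, 5
capped = arr.at[3:].min(3.5)           # last three entries capped at 3.5: 0, 1, 2, 3, 3.5, 3.5

print(doubled)
print(capped)
print(arr)                             # original is unchanged
```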


Another issue with JAX's approach  