In [None]:
import jax

# JAX tutorial

# Plan

* what is jax?
* what is jax in practice?

# JAX philosophy

* What is a good framework for machine learning?

* The one that you don't need to learn!

* Python (i.e. numpy)

* Just plain math

# JAX as numpy

In [None]:
import numpy as np

      ↓

In [None]:
import jax.numpy as jnp

# Thanks for your attention!

## Any questions?

# Example 1

In [None]:
a = jnp.ones((2, 8192), dtype=jnp.float32)
b = jnp.ones((8192, 4), dtype=jnp.float32)
a.dot(b)

# Example 2: computation devices

In [None]:
jax.devices()

Let's see what happens in Google Colab

# Example 3: differentiation

In [None]:
from jax import grad
def f(x):
    if x > 0:
        return x * 2
    else:
        return x ** 2

In [None]:
print('value', f(3.0), 'gradient', grad(f)(3.0))

In [None]:
print('value', f(-3.0), 'gradient', grad(f)(-3.0))

# What just happened?

We **created** the derivative as a **function** and then evaluated that derivative function

![gradient](gradient.png)
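
Because `grad` returns an ordinary Python function, that function can be stored, called many times, or differentiated again. A minimal sketch, where `cube` is a hypothetical example function that does not appear in the slides:

```python
import jax

def cube(x):
    return x ** 3

d_cube = jax.grad(cube)      # a new function computing 3 * x**2
d2_cube = jax.grad(d_cube)   # differentiate again: 6 * x

first_derivative = d_cube(3.0)
second_derivative = d2_cube(3.0)
print('first derivative', first_derivative, 'second derivative', second_derivative)
```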

# Automatic differentiation in JAX

* We can differentiate ANY computation!

* Or can't we?

* What if my code e.g. goes to the internet?

# Automatic differentiation in JAX

So how can we tell that the computation is differentiable?

The answer is...

Functional programming!

# JAX as a functional framework

For us it’s enough to say that

* Functions don’t have side effects - this is what you need to do


* Some functions input and output other functions  (e.g. jax.grad) - all of them are implemented by jax


# JAX as a functional framework

All functions are "pure", i.e. mathematical functions rather than ones that affect the outside world

In [None]:
state = 0
# This function is "impure"
def f(x):
    global state
    state = state + 1
    return state + x
print(f(1), f(1))

# JAX as a functional framework

In [None]:
# This function is "pure"
def f(x):
    return x * 2
print(f(1), f(1))

In [None]:
state = 0
# This function only reads `state`, so it behaves as "pure" as long as `state` never changes
def f(x):
    return x + state
print(f(1), f(1))

# JAX as a functional framework

Any function can be "purified". State here is "how many times we called the function"

In [None]:
state = 0
# before
def impure_f(x):
    global state
    state = state + 1

    return state, x
# after
def pure_f(state_, x):
    state_ = state_ + 1
    return state_, x

In [None]:
print('init state', state,
      ', first impure call', impure_f(0), ', new_state', state)

# JAX as a functional framework

Second reason - transformations of computation

In [None]:
from jax import grad # differentiation: (f) -> (gradient of f)
from jax import jit # speed up or "compile" computation: (f) -> (faster f)

These transformations assume pure computation!
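
Since each transformation takes a function and returns a function, transformations can be chained. A small sketch, assuming a hypothetical `quadratic` function chosen only for illustration:

```python
import jax

def quadratic(x):
    return x ** 2 + 3.0 * x

# grad builds the derivative function, jit compiles that derivative
fast_slope = jax.jit(jax.grad(quadratic))

slope_at_two = fast_slope(2.0)  # derivative is 2 * x + 3
print(slope_at_two)
```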

# JAX as a functional framework

Let's see what happens if we compile pure and impure functions. The state is "how many times we called this function"

In [None]:
print('state', state)
fast_impure_f = jit(impure_f)
print('first "impure" call', fast_impure_f(0)[0], '; second "impure" call', fast_impure_f(0)[0])

The state "froze" at the value it had at compilation time!
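
One way to see why this happens: `jit` runs the Python body only while tracing, and later calls with the same input shape and dtype reuse the compiled result. The sketch below uses a hypothetical `trace_log` list as the side effect to make the tracing visible:

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical side-effect target, only for illustration

@jax.jit
def traced_double(x):
    trace_log.append('traced')  # Python side effect: runs only during tracing
    return x * 2

first_call = traced_double(jnp.float32(1.0))
second_call = traced_double(jnp.float32(2.0))  # same shape and dtype: no re-trace

print('results', first_call, second_call)
print('number of times the Python body ran:', len(trace_log))
```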

# JAX as a functional framework

With pure functions this can be done in "mathematical" style

In [None]:
fast_pure_f = jit(pure_f)

initial_state = 0
first_state, first_result = fast_pure_f(initial_state, 0)
second_state, second_result = fast_pure_f(first_state, 0)
print('first state', first_state, 'second state', second_state)

# JAX as a functional framework

Why do we need this? Jit makes everything very fast!

In [None]:
def power(a):
    accum = a
    for j in range(999):
        accum = accum.dot(a)
    return accum

In [None]:
%%timeit
result = power(jnp.eye(5)).block_until_ready()

In [None]:
%%time
compiled_power = jit(power)
result = compiled_power(jnp.eye(5))

In [None]:
%%timeit
result = compiled_power(jnp.eye(5)).block_until_ready()

# JAX as a functional framework

Array assignment

In [None]:
a = jnp.zeros((3, 3))
# this does not work!
# a[:, 2:] = 1

In [None]:
# do this instead: `.at[...].set(...)` returns a new array and leaves `a` unchanged
a.at[:, 2:].set(1)

**This concept is used in homework, remember it!**
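
The `.at[...]` property supports other updates besides `set`, such as `add`, and each one returns a fresh array. A short sketch with a hypothetical array named `grid`:

```python
import jax.numpy as jnp

grid = jnp.zeros((3, 3))

with_ones = grid.at[:, 2:].set(1)     # new array with the last column set to 1
with_sum = with_ones.at[0, 0].add(5)  # new array with 5 added to the top-left entry

print('original is untouched:', grid.sum())
print(with_sum)
```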

# JAX as a functional framework. Summary:

We need the functional paradigm for two reasons:

* functions compile into fast computations
* functions can be transformed into other functions

# So what is JAX?

JAX is
* JIT compilation
* Autodiff
* XLA

# Questions about Part 1?

# Part two: JAX in practice

Let's build a neural network!

# Random numbers in JAX

In [None]:
def getRandomNumber() -> int:
  return 4 # I tossed a fair dice so this number is random!

# Random numbers in JAX

We use "random states" (or keys) for random numbers

In [None]:
key = jax.random.key(43)
sample = jax.random.randint(key, shape=(1,), minval=0, maxval=6)
print(sample)

Let's draw one more sample

In [None]:
sample = jax.random.randint(key, shape=(1,), minval=0, maxval=6)
print(sample)

# Random generators are pure functions!
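
A minimal sketch of what purity means here: the same key always yields the same numbers, and fresh numbers require a fresh key obtained by splitting. The names `demo_key`, `left_key` and `right_key` are made up for this example:

```python
import jax
import jax.numpy as jnp

demo_key = jax.random.key(0)

# same key, same output
draw_a = jax.random.normal(demo_key, shape=(3,))
draw_b = jax.random.normal(demo_key, shape=(3,))

# split the key to get two independent streams
left_key, right_key = jax.random.split(demo_key, 2)
draw_left = jax.random.normal(left_key, shape=(3,))
draw_right = jax.random.normal(right_key, shape=(3,))

print('same key gives equal draws:', bool(jnp.all(draw_a == draw_b)))
print('split keys give different draws:', draw_left, draw_right)
```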

# Random numbers in JAX

In [None]:
def roll_dice(key, maxval):
    key, subkey = jax.random.split(key, 2)
    sample = jax.random.randint(subkey, shape=(1,), minval=0, maxval=maxval)
    return key, sample

In [None]:
key = jax.random.key(43)
first_key, first_sample = roll_dice(key, maxval=6)
print('first sample', first_sample)
second_key, second_sample = roll_dice(first_key, maxval=6)
print('second sample', second_sample)

# Random numbers in JAX

As you can see, we use the same principle for stateful computations:

`previous_state, input -> f -> new_state, output`

# VMAP in JAX

This is the most powerful transformation in jax

Roughly speaking, vmap works as follows.

Say we have $f : \mathbb{R} \rightarrow \mathbb{R}$

then $\text{vmap}(f) : \mathbb{R}^n \rightarrow \mathbb{R}^n$

such that $f(x_k) = \text{vmap}(f)(x)_k$

vmap **vectorizes the computation**
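
The defining property above can be checked directly. This sketch uses a hypothetical scalar function `scalar_fn` and compares `vmap` against an explicit Python loop:

```python
import jax
import jax.numpy as jnp

def scalar_fn(x):
    return x ** 2 + 1.0

xs = jnp.arange(4.0)  # [0., 1., 2., 3.]

batched = jax.vmap(scalar_fn)(xs)                    # vmap(f)(x)
elementwise = jnp.stack([scalar_fn(x) for x in xs])  # f(x_k) for each k

print(batched)
print(elementwise)
```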

# VMAP Example

In [None]:
def inner_prod(a, b):
    return (a * b).sum()

In [None]:
a = jnp.ones((3,))
b = jnp.ones((3,))
print(inner_prod(a, b))

In [None]:
# inner_prod - a : (n,) b : (n,) -> () (a scalar)
matrix_matrix_prod = jax.vmap(inner_prod)
# matrix_matrix_prod - a : (m, n) b : (m, n) -> (m,)
matrix_matrix_prod(jnp.ones((3, 4)), jnp.ones((3, 4)))

In [None]:
# inner_prod - a : (n,) b : (n,) -> () (a scalar)
matrix_vector_prod = jax.vmap(inner_prod, in_axes=(0, None), out_axes=0)
# matrix_vector_prod - a : (m, n) b : (n,) -> (m,)
matrix_matrix_prod = jax.vmap(matrix_vector_prod, in_axes=(None, 1), out_axes=1)
# matrix_matrix_prod - a : (m, n) b : (n, k) -> (m, k)
matrix_matrix_prod(jnp.ones((3, 4)), jnp.ones((4, 3)))

In [None]:
# compare that with a proper dot product
jnp.ones((3, 4)) @ jnp.ones((4, 3))

# Sharp bits of VMAP

Can we vmap `if`?

This does not work automatically - we need to reduce it to array manipulation!

In [None]:
# before
def f(x): # x can only be scalar here
    if x > 0:
        return x * 2
    else:
        return x ** 2
# after
def f(x):
    return jnp.where(x > 0, x * 2, x ** 2)

**You'll need to use this idea in homework, remember it!**
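
Once the branch is written with `jnp.where`, the function composes with other transformations too. A sketch, assuming a renamed copy of the function above called `piecewise`, that computes the derivative for a whole batch at once:

```python
import jax
import jax.numpy as jnp

def piecewise(x):
    # 2 * x for positive x, x ** 2 otherwise
    return jnp.where(x > 0, x * 2, x ** 2)

xs = jnp.array([-3.0, 3.0])

values = jax.vmap(piecewise)(xs)
slopes = jax.vmap(jax.grad(piecewise))(xs)  # per-element derivative

print('values', values, 'gradients', slopes)
```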

# Now let's solve something real!

Least squares problem
$$ \arg\min_{x} \|A x - b\|_2$$
Through optimization, where $A \in \mathbb{R}^{n,m}; x \in \mathbb{R}^{m}; b \in \mathbb{R}^{n}$

In [None]:
# create target values and oracle solution
key = jax.random.key(1)

key, subkey = jax.random.split(key, 2)
A = jax.random.normal(subkey, shape=(3,3))
A += jnp.eye(3) * 5 # make the optimization well conditioned

key, subkey = jax.random.split(key, 2)
b = jax.random.normal(subkey, shape=(3,)) + 1

# oracle solution
x_oracle, _, _, _ = np.linalg.lstsq(A, b)

In [None]:
key, subkey = jax.random.split(key, 2)
X = jax.random.normal(subkey, shape=(3,))

@jax.jit
def update_step(X, A, b):
    def objective(x):
        # A - [n, m], x - [m], b - [n]
        residual = A.dot(x) - b
        return (residual ** 2).sum() ** 0.5
    criterion, gradient = jax.value_and_grad(objective)(X)
    X = X - 0.0001 * gradient
    return X, criterion

In [None]:
from tqdm import tqdm
for _ in tqdm(range(10000)):
    X, loss = update_step(X, A, b)
print(loss)

In [None]:
print('oracle solution', x_oracle, '\njax optimization solution', X)
print('diff', x_oracle - X)
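
As a sanity check on the objective used above, the gradient of $\|A x - b\|_2$ is $A^\top (A x - b) / \|A x - b\|_2$ wherever the residual is nonzero. The sketch below compares that formula with `jax.grad` on small hand-picked values; `A_demo`, `b_demo` and `x_demo` are illustrative and are not the arrays from the cells above:

```python
import jax
import jax.numpy as jnp

A_demo = jnp.array([[5.0, 1.0, 0.0],
                    [0.0, 4.0, 1.0],
                    [1.0, 0.0, 6.0]])
b_demo = jnp.array([1.0, 2.0, 3.0])
x_demo = jnp.zeros(3)

def demo_objective(x):
    residual = A_demo @ x - b_demo
    return jnp.sqrt((residual ** 2).sum())

auto_gradient = jax.grad(demo_objective)(x_demo)

# closed-form gradient of the Euclidean norm of the residual
residual = A_demo @ x_demo - b_demo
manual_gradient = A_demo.T @ residual / jnp.linalg.norm(residual)

print('autodiff', auto_gradient)
print('formula ', manual_gradient)
```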

# Now let's solve something even more real!

In [None]:
import keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
print('train input shape:', x_train.shape)
print('train output shape:', y_train.shape)
print('test input shape:', x_test.shape)
print('test output shape:', y_test.shape)

In [None]:
import matplotlib.pyplot as plt
plt.imshow(x_train[101])

In [None]:
import flax.linen as nn # for neural networks
import optax # for optimization
from functools import partial

## Define Neural Network

In [None]:
class CNN(nn.Module):
    num_features: int = 32
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=self.num_features, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.Conv(features=self.num_features, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=10)(x) # logits for softmax
        return x

In [None]:
model = CNN()
dummy_input = jnp.ones((1, 28, 28, 1))  # (N, H, W, C) format
# initialize
key = jax.random.key(0)
key, subkey = jax.random.split(key, 2)
parameters = model.init(subkey, dummy_input)

What does `model` hold?

In [None]:
model

What does `parameters` hold?

In [None]:
parameters

How does jax manipulate them?

In [None]:
jax.tree.map(lambda x: x+ 100, parameters)

Jax treats all python objects as either arrays or PyTrees (e.g. dicts of arrays)

`jax.tree.map` is a convenient way for PyTree manipulation.
**Note that you'll need to use this function in homework!!!**

We can also combine two or more PyTrees of the same structure.

In [None]:
jax.tree.map(lambda x, y: x + y, parameters, parameters)
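
A typical use of the two-tree form is a gradient descent update, where parameters and gradients share one structure. A minimal sketch with a hand-made PyTree; `toy_tree`, `toy_tree_grads` and the step size `0.1` are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

toy_tree = {'dense': {'kernel': jnp.ones((2, 3)), 'bias': jnp.zeros(3)}}

# pretend gradients: same structure, every entry equal to 1
toy_tree_grads = jax.tree.map(jnp.ones_like, toy_tree)

# one gradient descent step applied leaf by leaf
updated_tree = jax.tree.map(lambda p, g: p - 0.1 * g, toy_tree, toy_tree_grads)

print(updated_tree)
```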

Run prediction on an untrained network

In [None]:
model.apply(parameters, dummy_input)
# maps a (1, 28, 28, 1) batch of images into (1, 10) logits

## Create optimizer

In [None]:
optimizer = optax.sgd(learning_rate=1e-3) # stochastic gradient descent

## Create Training State

In [None]:
from flax.training.train_state import TrainState

state = TrainState.create(
    apply_fn=model.apply, # specify how to run inference in neural net
    params=parameters,
    tx=optimizer,
)

## Define one training step

In [None]:
@jax.jit
def train_step(state, x, y):
    def loss_fn(params, local_x, local_y):
        preds = state.apply_fn(params, local_x)
        loss_batch = optax.softmax_cross_entropy_with_integer_labels(preds, local_y)
        return loss_batch.mean() # compute average loss on a batch
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss

## Define one evaluation step

In [None]:
@jax.jit
def eval_model(state, x, y):
    def metrics_fn(params, local_x, local_y):
        preds = state.apply_fn(params, local_x)
        loss_batch = optax.softmax_cross_entropy_with_integer_labels(preds, local_y)
        accuracy = (preds.argmax(axis=1) == local_y)
        return loss_batch.mean(), accuracy.mean() # compute average loss and accuracy on a batch
    loss, accuracy = metrics_fn(state.params, x, y)
    return state, loss, accuracy

# Run Training!

In [None]:
BATCH = 500
import math
for epoch in range(100):
    print('epoch', epoch)
    key, subkey = jax.random.split(key, 2)
    permutation = jax.random.permutation(subkey, len(x_train))
    epoch_x = x_train[permutation]
    epoch_y = y_train[permutation]
    for step in tqdm(range(math.ceil(len(x_train) / BATCH))): # round up so that a final partial batch is not dropped
        batch_x = epoch_x[step * BATCH: (step + 1) * BATCH][..., None] / 255.
        batch_y = epoch_y[step * BATCH: (step + 1) * BATCH]
        state, loss = train_step(state, batch_x, batch_y)
    _, test_loss, test_acc = eval_model(state, x_test[..., None] / 255., y_test)
    print(test_loss, test_acc)

# It works!

# The last two concepts for today: scanning and dataclasses

Run the computation through the sequence of inputs "carrying" the state

In [None]:
state = jax.random.key(0)
inputs = jnp.arange(10) + 1 # from 1 to 10
def f(state, inp):
    state, current_key = jax.random.split(state, 2)
    return state, jax.random.randint(current_key, shape=(1,), minval=0, maxval=inp)
last_state, outputs = jax.lax.scan(f, state, inputs)
outputs.squeeze()

In [None]:
# equivalent code
state = jax.random.key(0)
inputs = jnp.arange(10) + 1 # from 1 to 10
outputs = []
for j in range(len(inputs)):
    state, output = f(state, inputs[j])
    outputs.append(output)
outputs = jnp.stack(outputs)
outputs.squeeze()

In practice, people `scan` over:
* train steps of neural net (i.e. the state is `TrainState`)
* layers of neural network (the state is the activation vector)
* sequences of random variable samples (the state is the random state)
* RL environment stepping (the state is the state of MDP)
* recurrent steps of recurrent neural networks (the state is the hidden state vector of RNN)

We'll see examples of this in homework (don't worry, you won't need to implement any scan)
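
The simplest instance of the pattern is a running sum, where the carried state is the total so far. A sketch with a hypothetical step function `running_sum_step`:

```python
import jax
import jax.numpy as jnp

def running_sum_step(total, x):
    total = total + x
    return total, total  # (new state, output for this step)

scan_inputs = jnp.arange(1.0, 6.0)  # [1., 2., 3., 4., 5.]
initial_total = jnp.zeros(())       # float32 scalar, same type as the carried total

final_total, running_totals = jax.lax.scan(running_sum_step, initial_total, scan_inputs)

print('final state', final_total)
print('outputs', running_totals)
```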

Dataclasses are simply containers for arrays and functions (we've seen one already - it is `TrainState` in `flax`, and **any class inherited from TrainState**)

In [None]:
import chex
from typing import Callable

@chex.dataclass(frozen=True)
class MyStorage:
  array_a: chex.Array
  array_b: chex.Array
  my_function: Callable

Dataclasses can be updated via the `replace` method

In [None]:
storage = MyStorage(array_a=jnp.zeros(5), array_b=jnp.ones(2), my_function=lambda x: x)
updated_storage = storage.replace(array_a=jnp.ones(5))
print('old storage', storage)
print('updated storage', updated_storage)

# That's it!

# Congrats, you now know jax! See you on the exam

Fun facts about jax:

* JAX is a very popular tool for reinforcement learning since it can speed up and parallelize the whole training
  * Neural net can be in jax
  * RL environment can be literally written in jax
  * the whole reinforcement learning can be JIT-ed as `prev_agent, environment -> RL -> new_agent, final_return`

* The Gemini model created by Google is implemented and trained with pure jax, since jax abstractions work well for large-scale distributed training

* Even though jax can't differentiate every computation, it can differentiate the computations we do in practice. You can compute the gradient of the final training loss after 100 epochs of training w.r.t. the learning rate (you'll run out of memory though, but that is another problem)
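
The last point can be tried on a tiny scale. The sketch below differentiates the final loss of a five-step gradient descent run with respect to the learning rate; `toy_objective`, `train_and_evaluate` and all constants are made up for illustration:

```python
import jax

def toy_objective(w):
    return (w - 3.0) ** 2

def train_and_evaluate(learning_rate, num_steps=5):
    w = 0.0
    for _ in range(num_steps):
        # inner gradient: one ordinary gradient descent step
        w = w - learning_rate * jax.grad(toy_objective)(w)
    return toy_objective(w)  # final training loss

# outer gradient: how the final loss changes with the learning rate
d_loss_d_lr = jax.grad(train_and_evaluate)(0.1)
print(d_loss_d_lr)
```

The result is negative here, meaning a slightly larger learning rate would lower the final loss of this toy run.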

# Summary

* Jax is a functional programming framework

* It offers a variety of computation transformations, including automatic differentiation, vectorization, and compilation

* stateful computations are done via explicit "scanning" of state data: `f: (state, input) -> (state, output)`