**S02P01_tutorial_jax_as_accelerated_numpy.ipynb**

Arz

2024 APR 08 (MON)

reference:
https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# getting started with JAX NumPy

## type

In [4]:
print(type(np.arange(3)), type(jnp.arange(3)))

<class 'numpy.ndarray'> <class 'jaxlib.xla_extension.ArrayImpl'>
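The two array types convert into each other with `jnp.asarray` and `np.asarray`. Here is a minimal sketch, assuming JAX 0.4 or later, where `jax.Array` is the public type behind the concrete `ArrayImpl` class printed above:

```python
import numpy as np
import jax
import jax.numpy as jnp

np_arr = np.arange(3)
jax_arr = jnp.asarray(np_arr)   # NumPy -> JAX (placed on the default device)
back = np.asarray(jax_arr)      # JAX -> NumPy (brought back to host memory)

print(isinstance(jax_arr, jax.Array))  # ArrayImpl is an instance of jax.Array
print(type(back))
```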


## ex) dot product (multiplication)

In [5]:
x = jnp.arange(int(1e7))

%timeit jnp.dot(x, x).block_until_ready()

397 µs ± 56.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


result on Colab (`jnp.dot` timing by device):

| **device** | **time per loop** |
|---------|----------|
| **CPU** | 8.8 [ms] |
| **GPU** | 853 [µs] |
| **TPU** |          |
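JAX dispatches work asynchronously, so `jnp.dot` returns before the computation has finished. `block_until_ready()` makes the timer wait for the actual result. The sketch below repeats the measurement with the standard library and prints `jax.default_backend()` to show which device the number belongs to. It assumes a float32 input: `jnp.arange(int(1e7))` above is int32, so its dot product overflows. That does not matter for timing, but it does matter for the value. The helper `timed` is hypothetical.

```python
import time
import jax
import jax.numpy as jnp

print(jax.default_backend())  # e.g. 'cpu', 'gpu' or 'tpu'

x = jnp.arange(int(1e7), dtype=jnp.float32)

def timed(fn, n=10):
    """Average wall time of fn() over n runs, waiting for the device each time."""
    fn().block_until_ready()  # warm-up run (first call may compile), not timed
    start = time.perf_counter()
    for _ in range(n):
        fn().block_until_ready()
    return (time.perf_counter() - start) / n

t = timed(lambda: jnp.dot(x, x))
print(f"{t * 1e6:.0f} µs per call")
```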

# JAX's first transformation: grad

In [6]:
def sum_of_squares(x):
    return jnp.sum(x**2)

In [7]:
sum_of_squares_grad = grad(sum_of_squares)

In [8]:
x = jnp.asarray([1., 2., 3.])

print(sum_of_squares(x))
print(sum_of_squares_grad(x))

14.0
[2. 4. 6.]
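For `sum_of_squares` the gradient has a closed form, ∂/∂x_i Σ_j x_j² = 2·x_i, which matches the `[2. 4. 6.]` printed above. The sketch below checks `grad` against that formula and against central finite differences. The step size `eps` is an arbitrary choice for this sketch.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x**2)

x = jnp.asarray([1., 2., 3.])
g = jax.grad(sum_of_squares)(x)

# analytic gradient: d/dx_i sum_j x_j^2 = 2 * x_i
analytic = 2 * x

# central finite differences along each coordinate direction
eps = 1e-2
basis = jnp.eye(x.shape[0])
fd = jnp.array([
    (sum_of_squares(x + eps * e) - sum_of_squares(x - eps * e)) / (2 * eps)
    for e in basis
])

print(jnp.allclose(g, analytic), jnp.allclose(g, fd, atol=1e-2))
```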


by default, jax.grad computes the gradient w.r.t. the **first** argument only.

In [9]:
def f(x, y):
    return jnp.sum((3*x - y)**2)

In [10]:
df_dx = grad(f)  # 2*(3*x - y)*3

In [11]:
x = jnp.asarray([1.0, 2.0, 3.0])
y = jnp.asarray([1.1, 2.1, 3.1])

print(df_dx(x, y))
print(2*(3*x - y)*3)

[11.4      23.400002 35.4     ]
[11.4      23.400002 35.4     ]


use **argnums** to specify which arguments jax.grad differentiates with respect to.

In [12]:
f_grad = grad(f, argnums=(0, 1))  # find gradient w.r.t. both x and y

In [13]:
x = jnp.asarray([1.0, 2.0, 3.0])
y = jnp.asarray([1.1, 2.1, 3.1])

print(f_grad(x, y))
print(2*(3*x - y)*3, 2*(3*x - y)*-1)

(Array([11.4     , 23.400002, 35.4     ], dtype=float32), Array([ -3.8,  -7.8, -11.8], dtype=float32))
[11.4      23.400002 35.4     ] [ -3.8  -7.8 -11.8]
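`argnums` also accepts a single integer. In that case `grad` returns one array instead of a tuple. Differentiating only w.r.t. `y` should reproduce ∂f/∂y = -2(3x - y), which is the second array printed above. A short sketch:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sum((3*x - y)**2)

x = jnp.asarray([1.0, 2.0, 3.0])
y = jnp.asarray([1.1, 2.1, 3.1])

df_dy = jax.grad(f, argnums=1)(x, y)   # a single array, not a tuple
print(df_dy)
print(-2*(3*x - y))                    # analytic: d/dy (3x - y)^2 = -2(3x - y)
```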


# value and grad

use jax.value_and_grad when both the value and the gradient are needed.

- ex) logging training loss

output form: (value, grad)

In [14]:
x = jnp.asarray([1.0, 2.0, 3.0])
y = jnp.asarray([1.1, 2.1, 3.1])

In [15]:
jax.value_and_grad(f)(x, y)

(Array(53.630005, dtype=float32),
 Array([11.4     , 23.400002, 35.4     ], dtype=float32))

In [16]:
jax.value_and_grad(f, argnums=(0, 1))(x, y)

(Array(53.630005, dtype=float32),
 (Array([11.4     , 23.400002, 35.4     ], dtype=float32),
  Array([ -3.8,  -7.8, -11.8], dtype=float32)))

where the value is simply f(x, y):

In [17]:
f(x, y)

Array(53.630005, dtype=float32)
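The typical use is a training step that needs the loss for logging and the gradient for the update. `value_and_grad` returns both from one call, so there is no need to evaluate the function twice. This is a minimal sketch. `loss_fn`, `step`, the learning rate and the toy data are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # hypothetical toy loss: mean squared error of a scalar-weight model w*x
    return jnp.mean((w * x - y) ** 2)

@jax.jit
def step(w, x, y, lr=0.1):
    loss, g = jax.value_and_grad(loss_fn)(w, x, y)  # value and gradient together
    return w - lr * g, loss

x = jnp.asarray([1.0, 2.0, 3.0])
y = 2.0 * x
w = jnp.asarray(0.0)
losses = []
for i in range(50):
    w, loss = step(w, x, y)
    losses.append(float(loss))
    if i % 10 == 0:
        print(f"step {i:2d}  loss {losses[-1]:.4f}")  # log the loss
```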

# auxiliary data

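the function must return (scalar output, aux); with has_aux=True, grad returns (grad, aux), and aux is passed through without being differentiated.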

In [18]:
def f_with_aux(x, y):
    return f(x, y), 3*x - y

In [19]:
# grad(f_with_aux)(x, y)  # forbidden: raises an error

# because grad requires a scalar output, while f_with_aux returns a tuple

In [20]:
grad(f_with_aux, has_aux=True)(x, y)

(Array([11.4     , 23.400002, 35.4     ], dtype=float32),
 Array([1.9, 3.9, 5.9], dtype=float32))

In [21]:
print(3*x - y)

[1.9 3.9 5.9]
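`has_aux` works the same way with `jax.value_and_grad`. The result is then nested as `((value, aux), grad)`. Here is a sketch that redefines `f` and `f_with_aux` from above so it runs on its own:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sum((3*x - y)**2)

def f_with_aux(x, y):
    # first element: scalar to differentiate; second: auxiliary data passed through
    return f(x, y), 3*x - y

x = jnp.asarray([1.0, 2.0, 3.0])
y = jnp.asarray([1.1, 2.1, 3.1])

(value, aux), g = jax.value_and_grad(f_with_aux, has_aux=True)(x, y)
print(value)  # same scalar as f(x, y)
print(aux)    # 3*x - y, not differentiated
print(g)      # gradient w.r.t. x
```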


# differences from NumPy

## ex) in-place modification

### side-effect

In [22]:
def modify_in_place(x):
    x[0] = 0
    return None

In [23]:
x = np.array([1, 2, 3])

modify_in_place(x)
x

array([0, 2, 3])

In [24]:
# modify_in_place(jnp.array(x))  # forbidden: JAX arrays are immutable, so item assignment raises an error

### side-effect-free (functionally pure)

In [25]:
def modify_in_place(x):
    # functionally pure: returns an updated copy; the input x is left unchanged
    return x.at[0].set(0)

In [26]:
x = np.array([1, 2, 3])

y = modify_in_place(jnp.array(x))
y

Array([0, 2, 3], dtype=int32)

note: the original array is untouched, so there is no side effect.

In [27]:
x

array([1, 2, 3])
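`.at[...]` supports more than `set`. Each indexed update below returns a new array and leaves its input unchanged. According to the JAX documentation, inside `jit` XLA may still perform such an update in place if the original buffer is not reused afterwards. A short sketch:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

added = x.at[0].add(10)                      # add 10 to element 0
doubled = x.at[1:].mul(2)                    # multiply a slice by 2
third = x.at[2].get()                        # indexed read
zeroed = x.at[jnp.array([0, 2])].set(0)      # set several indices at once

print(added, doubled, third, zeroed)
print(x)                                     # still the original values
```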

# your first JAX training loop

In [28]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

## ex) linear regression

model: y_hat(theta, x) = w*x + b

parameters: theta = (w, b)

In [29]:
x = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))

y = 3*x - 1 + noise
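The cell above draws from NumPy's global random state, so the data change on every run. `jax.random` is imported at the top but never used. As an alternative, the same synthetic data can be drawn with JAX's explicit, splittable PRNG keys, which makes the run reproducible. This is a sketch, and the seed `0` is an arbitrary choice.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)                  # explicit seed; no hidden global state
key_x, key_noise = random.split(key)     # independent keys for each random draw

x = random.normal(key_x, shape=(100,))
noise = 0.1 * random.normal(key_noise, shape=(100,))  # std 0.1, like scale=0.1 above

y = 3*x - 1 + noise
```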

In [30]:
fig = px.scatter(x=x, y=y)
fig.show()

In [31]:
def model(theta, x):
    w, b = theta
    return w*x + b

In [32]:
def loss_function(theta, x, y):
    y_pred = model(theta, x)
    return jnp.mean((y - y_pred)**2)

in JAX, it’s common to define an update() function that is called every step, taking the current parameters as input and returning the new parameters. this is a natural consequence of JAX’s functional nature.

In [33]:
@jit
def update_theta(theta, x, y, alpha=0.1):
    return theta - alpha*grad(loss_function)(theta, x, y)
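Here `theta` is a single array `[w, b]`, so the update is one vectorized expression. Parameters are often kept in a nested structure such as a dict or tuple instead. In that case `grad` returns gradients with the same structure, and the update can be mapped over it with `jax.tree_util.tree_map`. This sketch assumes a hypothetical dict layout `{"w": ..., "b": ...}` and noise-free toy data.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return params["w"] * x + params["b"]

def loss_function(params, x, y):
    return jnp.mean((y - model(params, x)) ** 2)

@jax.jit
def update_params(params, x, y, alpha=0.1):
    grads = jax.grad(loss_function)(params, x, y)  # same dict structure as params
    # apply the gradient-descent step leaf by leaf
    return jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)

x = jnp.linspace(-1.0, 1.0, 50)
y = 3*x - 1
params = {"w": jnp.asarray(1.0), "b": jnp.asarray(1.0)}
for _ in range(1000):
    params = update_params(params, x, y)
```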

**training**

In [34]:
theta = jnp.array([1., 1.])

for _ in range(1000):
    theta = update_theta(theta, x, y)

In [44]:
fig = px.scatter(x=x, y=y)
fig_model = px.line(x=x, y=model(theta, x))
fig_model.data[0].line.color = "#e02a19"
fig.add_trace(fig_model.data[0])

fig.show()

In [45]:
w, b = theta
print(f"w: {w:<.2f}, b: {b:<.2f}")

w: 3.00, b: -1.01
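Linear regression also has a closed-form least-squares solution, which gives a convenient check that gradient descent converged. The sketch below regenerates data like the cells above and compares both answers using `jnp.linalg.lstsq`. The data are seeded here, which is an assumption, so the numbers will differ slightly from the printed `w` and `b`.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)              # seeded for reproducibility
x = rng.normal(size=(100,))
y = 3*x - 1 + rng.normal(scale=0.1, size=(100,))

def loss_function(theta, x, y):
    w, b = theta
    return jnp.mean((y - (w*x + b))**2)

@jax.jit
def update_theta(theta, x, y, alpha=0.1):
    return theta - alpha*jax.grad(loss_function)(theta, x, y)

theta = jnp.array([1., 1.])
for _ in range(1000):
    theta = update_theta(theta, x, y)

# closed form: solve [x 1] @ [w, b] ≈ y in the least-squares sense
A = jnp.stack([x, jnp.ones_like(x)], axis=1)
theta_ls, *_ = jnp.linalg.lstsq(A, y)

print(theta, theta_ls)
```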
