**S01P02_how_to_think.ipynb**

Arz

2024 APR 04 (THU)

reference:
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

In [1]:
import jax.numpy as jnp
from jax import lax

In [2]:
import numpy as np

In [3]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

In [4]:
%xmode minimal

Exception reporting mode: Minimal


# plot test

In [5]:
# Sample 1000 evenly spaced points on [0, 10] as input for the plotting smoke test.
x = jnp.linspace(0, 10, 1000)
# 2*sin(x)*cos(x) is the double-angle form of sin(2x).
y = 2*jnp.sin(x)*jnp.cos(x)

CUDA backend failed to initialize: Unable to load CUDA. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [6]:
# Give the figure a title and a meaningful y-axis label so it can be read
# on its own when the notebook is skimmed.
fig = px.line(
    x=x,
    y=y,
    markers=True,
    title="plot test: y = 2 sin(x) cos(x)",
    labels={"x": "x", "y": "2 sin(x) cos(x)"},
)
fig.show()

# about type

In [7]:
type(x)

jaxlib.xla_extension.ArrayImpl

## NumPy array: mutable

In [8]:
x = np.arange(10)
x[0] = 7  # in-place item assignment works: NumPy arrays are mutable

In [9]:
print(x)

[7 1 2 3 4 5 6 7 8 9]


## JAX array: immutable

In [10]:
x = jnp.arange(10)
# x[0] = 7  # forbidden: JAX arrays are immutable, so in-place item assignment fails
x = x.at[0].set(7)  # .at[0].set(7) builds a modified copy, which is then rebound to x

In [11]:
print(x)

[7 1 2 3 4 5 6 7 8 9]


# jax.numpy vs jax.lax

## ex) promotion

In [12]:
jnp.add(1, 1.0)  # automatic promotion

Array(2., dtype=float32, weak_type=True)

In [13]:
# lax.add(1, 1.0)  # forbidden: lax.add does no implicit promotion, both operands must have the same dtype

In [14]:
lax.add(jnp.float32(1), 1.0)

Array(2., dtype=float32)

## ex) more general operations allowed by lax API

In [15]:
x = jnp.array([1, 2, 1])
y = jnp.ones(5)

In [16]:
z = jnp.convolve(x, y)

print(z)

[1. 3. 4. 4. 4. 3. 1.]


In [17]:
# Same convolution as jnp.convolve(x, y), written with the lower-level lax API:
# the operands are reshaped to rank 3 and the dtype promotion and padding are
# spelled out by hand.
# NOTE(review): the two leading size-1 axes are presumably (batch, feature)
# per the lax default layout — confirm against the lax documentation.
z = lax.conv_general_dilated(
    x.reshape(1, 1, len(x)).astype(float),  # explicit promotion
    y.reshape(1, 1, len(y)),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # pad len(y) - 1 on both sides -> output length len(x) + len(y) - 1

print(z)
print(z[0, 0])  # strip the two leading size-1 axes to get the 1-D result

[[[1. 3. 4. 4. 4. 3. 1.]]]
[1. 3. 4. 4. 4. 3. 1.]


# to JIT or not to JIT

In [18]:
from jax import jit

## to JIT

In [19]:
def norm(X):
    """Standardize the columns of ``X``.

    The per-column mean (axis 0) is subtracted first; the centered data is
    then divided by its own per-column standard deviation.
    """
    centered = X - X.mean(0)
    scale = centered.std(0)
    return centered / scale

In [20]:
norm_jit = jit(norm)

### comparison: result

In [21]:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))

In [22]:
np.allclose(norm(X), norm_jit(X), atol=1E-6)

True

### comparison: execution time

In [23]:
%timeit norm(X).block_until_ready()
%timeit norm_jit(X).block_until_ready()

1.29 ms ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.25 ms ± 666 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## not to JIT

In [24]:
def get_negatives(x):
    """Return the entries of ``x`` that are strictly negative.

    The length of the result depends on the values stored in ``x``,
    not only on its shape.
    """
    is_negative = x < 0
    return x[is_negative]

In [25]:
x = jnp.array(np.random.randn(10))

In [26]:
get_negatives(x)

Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

In [27]:
get_negatives_jit = jit(get_negatives)

In [28]:
# get_negatives_jit(x)  # forbidden: the shape of the output x[x < 0] depends on the values in x, so it is not static

# JIT mechanics: tracing & static variables

JIT traces a function.

In [29]:
@jit
def f(x, y):
    """Return ``dot(x + 1, y + 1)`` and print the arguments and the result.

    The ``print`` calls are ordinary Python side effects, so they show what
    the function body receives when it is executed by ``jit``.
    """
    print("f():")
    print(f"  x = {x}")
    print(f"  y = {y}")

    x_shifted = x + 1
    y_shifted = y + 1
    result = jnp.dot(x_shifted, y_shifted)

    print(f"  result = {result}")

    return result

In [30]:
x = np.random.randn(3, 4)
y = np.random.randn(4)
print(x, y)

f(x, y)

[[ 0.24124517 -1.25714202 -0.48511598 -0.98639282]
 [ 1.39783022  0.48784978  1.90996403 -0.26037156]
 [-0.49505737  1.34450656  0.59428027  0.61083763]] [-0.57855466  1.05148987 -0.48361592 -1.27473592]
f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

Tracer objects are printed. This is what JIT sees while it traces the function: the shape and dtype of each value, not the values themselves.

In [31]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
print(x2, y2)

f(x2, y2)

[[-0.82886446 -0.96895962  0.31626638 -1.30240421]
 [ 1.42542593 -0.01968225 -1.03792854 -1.98280329]
 [ 0.29157075 -0.1554043   0.87975108 -0.14639181]] [0.88247808 0.76965426 0.25971301 0.98657674]


Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

tracer objects are not printed. 

Since JIT has already traced and compiled the function for inputs of this shape and dtype, it simply reuses the cached compiled code, so the Python `print` calls are not executed again.

In [32]:
x3 = np.random.randn(2, 4)
y3 = np.random.randn(4)
print(x3, y3)

f(x3, y3)

[[-0.16529104  0.86774901  0.879805   -0.5616348 ]
 [ 1.26464251 -0.02158209 -0.37602974 -0.82070135]] [-1.50348735 -0.18001363  0.28808094 -0.11594711]
f():
  x = Traced<ShapedArray(float32[2,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>


Array([3.9201422, 0.6243042], dtype=float32)

tracer objects are printed. 

The shape of `x` has changed from (3, 4) to (2, 4). Compiled code is cached per input shape and dtype, so JIT has to trace the function again for the new input shape.

## view JAX expression (jaxpr)

A jaxpr is the expression that encodes the sequence of operations extracted from a function by tracing.

In [33]:
from jax import make_jaxpr

In [34]:
make_jaxpr(f)(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3] = pjit[
      name=f
      jaxpr={ lambda ; d:f32[3,4] e:f32[4]. let
          f:f32[3,4] = add d 1.0
          g:f32[4] = add e 1.0
          h:f32[3] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] f g
        in (h,) }
    ] a b
  in (c,) }

In [35]:
make_jaxpr(f)(x3, y3)

{ lambda ; a:f32[2,4] b:f32[4]. let
    c:f32[2] = pjit[
      name=f
      jaxpr={ lambda ; d:f32[2,4] e:f32[4]. let
          f:f32[2,4] = add d 1.0
          g:f32[4] = add e 1.0
          h:f32[2] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] f g
        in (h,) }
    ] a b
  in (c,) }

## partial tracing

If control flow depends on the value of an argument (not just its shape and dtype), that argument cannot be traced; it has to be marked as static instead.

In [36]:
@jit 
def f(x, make_negative):
    # Intentionally problematic: `make_negative` is traced under jit, so using
    # it in a Python conditional is value-dependent control flow.
    return -x if make_negative else x

In [37]:
# f(7, True)  # forbidden: Python control flow on a traced boolean argument is value-dependent

In [38]:
from functools import partial

In [39]:
@partial(jit, static_argnums=(1,))  # set the boolean argument as static w.r.t. jit.
def f(x, make_negative):
    """Return ``-x`` when ``make_negative`` is truthy, otherwise ``x``.

    ``make_negative`` (positional index 1) is declared static, so the branch
    below is taken on its concrete Python value.
    """
    if make_negative:
        return -x
    return x

In [40]:
f(7, True)  # now works

Array(-7, dtype=int32, weak_type=True)

# static vs traced operations

In [44]:
def f(x):
    """Flatten ``x`` into a 1-D array holding all of its elements.

    The element count is computed with NumPy from ``x.shape``. The integer
    dtype is forced because the product of an empty shape tuple (0-d input)
    would otherwise be the float ``1.0``, which ``reshape`` rejects.
    """
    n_elements = np.array(x.shape, dtype=int).prod()
    return x.reshape(n_elements)

In [50]:
x = np.ones((2, 3))

print(np.array(x.shape).prod())
print(np.array((2, 3)).prod())
f(x)

6
6


array([1., 1., 1., 1., 1., 1.])

## forbidden case

In [51]:
@jit
def f(x):
    # Intentionally broken: the shape passed to reshape is computed with
    # jax.numpy (jnp.array(...).prod()) instead of NumPy; the notes below
    # explain why this is forbidden under jit.
    return x.reshape(jnp.array(x.shape).prod())

In [53]:
x = jnp.ones((2, 3))

# f(x)  # forbidden: reshape is given a traced value as the target shape

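x is traced, but x.shape is static: it is a plain Python tuple of integers that is known at trace time.

However, jnp.array() and .prod() are JAX operations, so they are traced and turn the static x.shape into a traced value.

.reshape(<static shape>) is required because array shapes must be static.

So .reshape(<traced value>) is forbidden.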


## fixed case

- **numpy** for static operations.
- **jax.numpy** for traced operations.

In [54]:
@jit
def f(x):
    """Flatten ``x`` into a 1-D array under jit.

    The element count is computed with NumPy (not jax.numpy) so that it stays
    a static value usable as a shape. ``dtype=int`` keeps the product an
    integer even for a 0-d input, where the product of the empty shape tuple
    would otherwise be the float ``1.0`` and be rejected as a shape.
    """
    n_elements = np.prod(x.shape, dtype=int)
    return x.reshape((n_elements,))

In [58]:
x = jnp.ones((2, 3))

f(x)  # now works

Array([1., 1., 1., 1., 1., 1.], dtype=float32)