In [1]:
# All JAX arrays are instances of jax.Array, but we usually create them with jax.numpy functions such as jnp.arange or jnp.zeros.

In [2]:
import os

# Restrict JAX to specific GPUs. CUDA reads this variable when the GPU backend
# initializes, so set it before JAX touches the GPU (safest: before importing jax).
# Visible GPUs are renumbered from 0, so physical GPU 2 shows up as CudaDevice(id=0).
# "2,3" is specific to the author's machine; setdefault keeps any value already
# exported in the shell instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2,3")

In [3]:
import jax.numpy as jnp

In [8]:
# A small array is enough to inspect device placement and sharding below.
# NOTE: an earlier version used jnp.arange(2_000_000_000), which needed ~7.45 GiB
# of device memory and triggered an XLA rematerialization warning.
N_ELEMENTS = 1_000
x = jnp.arange(N_ELEMENTS)


In [9]:
# The set of devices holding x's data (a single GPU in the output shown here).
x.devices()

{CudaDevice(id=0)}

In [10]:
# How x is placed on devices; SingleDeviceSharding means one device holds the whole array.
x.sharding

SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)

In [11]:
# JAX transformations operate on Python functions written with JAX operations:
#   jax.jit  - just-in-time compile a function with XLA
#   jax.vmap - vectorize a function over a batch axis
#   jax.grad - compute the gradient of a function
#
# Transformations can also be applied using Python's decorator syntax, e.g.:
#
#   @jax.jit
#   def f(x):
#       ...
#
# (Written as comments rather than a bare string literal, so the cell does not
# echo the text back as output.)

# Tracing

Tracers are abstract stand-ins for arrays: JAX passes them to the function being transformed in order to record the sequence of operations
that the function encodes.

In [12]:
import jax

In [13]:
def f(x):
    """Print the argument, then return it incremented by one.

    f is wrapped with jax.jit below. The print is ordinary Python, so it runs
    while JAX traces f and shows the Tracer stand-in rather than concrete
    values; it does not become part of the compiled computation.
    """
    print(x)
    incremented = x + 1
    return incremented


f = jax.jit(f)

In [14]:
x = jnp.arange(5.0)  # float32 array of length 5
# Calling the jitted f triggers tracing, so print(x) inside f shows a Tracer.
result = f(x)

Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace>


What is happening here: when a jitted function is called, JAX first converts it to a jaxpr. To do that, it needs to know the shape and dtype
of the input and the operations applied to it. So instead of the concrete array, it passes in a Tracer object (which behaves like an array of
that shape and dtype, here `ShapedArray(float32[5])`), then records the operations performed on it and turns them into a jaxpr.

Also, `print` is a Python side effect, not a JAX operation: it runs only while tracing (so it shows the Tracer, not the actual values of x), and it is not recorded in the jaxpr.

In [16]:
# Show the jaxpr recorded while tracing f: only the pjit-wrapped `add` appears;
# the print(x) call inside f leaves no trace in it.
jax.make_jaxpr(f)(x)

{ lambda ; a:f32[5]. let
    b:f32[5] = pjit[
      name=f
      jaxpr={ lambda ; c:f32[5]. let d:f32[5] = add c 1.0 in (d,) }
    ] a
  in (b,) }

JAX operates on arrays; to handle lists, tuples, dicts and other nested containers of arrays uniformly, it uses the pytree abstraction.

In [17]:
import jax.numpy as jnp

In [28]:
# A pytree: a list holding two Python ints and a tuple of two float32 arrays.
# jnp.ones(2) has shape (2,); note that `(2)` is just the int 2, not a tuple.
params = [1, 2, (jnp.arange(4.0), jnp.ones(2))]

In [21]:
# Display the pytree: Python scalars and arrays mixed inside nested containers.
params

[1,
 2,
 (Array([0., 1., 2., 3.], dtype=float32), Array([1., 1.], dtype=float32))]

In [22]:
# Only the container structure is kept; each leaf is shown as *.
print(jax.tree.structure(params))

PyTreeDef([*, *, (*, *)])


In [23]:
# Flattened leaves in order; the Python ints 1 and 2 count as leaves, not just the arrays.
print(jax.tree.leaves(params))

[1, 2, Array([0., 1., 2., 3.], dtype=float32), Array([1., 1.], dtype=float32)]


In [24]:
from typing import NamedTuple

In [25]:
class Params(NamedTuple):
    """Two-field record used to show that a NamedTuple is treated as a pytree node."""

    a: int  # becomes the first leaf
    b: float  # becomes the second leaf


params = Params(a=1, b=5.0)


In [27]:
# A NamedTuple is handled as a pytree node: its fields a and b become the leaves.
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]


In [34]:
# Pseudorandom numbers:
# NumPy's legacy API relies on a global state set by numpy.random.seed().
# Since JAX aims to be stateless (functional) and to run across devices, a hidden global state does not fit.
# Therefore, JAX takes the random state (a "key") as an explicit input.

In [35]:
from jax import random

In [53]:
# Create an explicit PRNG key from the integer seed 1.
key = random.key(1)

In [54]:
# A typed key array (dtype key<...>); the raw key data it wraps is shown underneath.
key

Array((), dtype=key<fry>) overlaying:
[0 1]

In [46]:
# Sampling twice with the same key gives the same value.
print(random.normal(key) == random.normal(key))

True


In [None]:
# NEVER REUSE KEYS unless you want identical outputs.

In [55]:
# Reusing the same key again: the result is fully determined by the key.
random.normal(key)

Array(-1.1842843, dtype=float32)

In [56]:
# random.split is deterministic: the same key always yields the same pair of new
# keys, and the two outputs are statistically independent of each other.
# Conventionally, consume `subkey` for sampling and keep `new_key` for further splits.
new_key, subkey = random.split(key)

In [57]:
# Both new keys are typed key arrays whose raw data differs from the parent key's [0 1].
new_key, subkey

(Array((), dtype=key<fry>) overlaying:
 [2441914641 1384938218],
 Array((), dtype=key<fry>) overlaying:
 [3819641963 2025898573])