# 1st Equinox tutorial

In [1]:
!which python

/Users/tristantorchet/Desktop/Code/VSCode/LearningJAX/.venv/bin/python


# This notebook will follow Equinox official tutorial. It will go in depth and most likely will link to Python's core functionnalities (e.g dataclasses, abstract classes, ...)

# 1. Linear Layer and autograd

## 1.1. Autograd

In [6]:
import equinox as eqx
import jax
import jax.numpy as jnp

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

In [7]:
@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)

In [8]:
print(grads)

Linear(weight=f32[3,2], bias=f32[3])


## 1.2. Pytrees

### 1.2.1. JAX core: PyTrees

By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves:

In [42]:
# create a dummy model as a ABC class with dataclass
from dataclasses import dataclass
from abc import ABC, abstractmethod

# AbstractLayer is an abstract class which cannot be instantiated
# ABC classes with abstract methods insure that all subclasses implement the abstract methods
# (It doesn't make sense to use the @dataclass decorator with an ABC class)
class AbstractLayer(ABC):
    @abstractmethod
    def __call__(self, x):
        pass

@dataclass
class CustomLayer(AbstractLayer):
    weight: jax.Array
    bias: jax.Array
    name: str

    def __call__(self, x):
        return self.weight @ x + self.bias

@jax.tree_util.register_pytree_node_class
@dataclass
class PyTreeCustomLayer(AbstractLayer):
    weight: jax.Array
    bias: jax.Array
    name: str

    def __call__(self, x):
        return self.weight @ x + self.bias
    
    def tree_flatten(self):
        children = (self.weight, self.bias) # the children of the current node
        aux_data = (self.name,) # auxiliary data that are not part of the tree structure
        return children, aux_data
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


w = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
b = jax.random.normal(jax.random.PRNGKey(1), (3,))

custom_layer = CustomLayer(w, b, 'layer1')
pytree_custom_layer = PyTreeCustomLayer(w, b, 'layer1')
show_example(custom_layer)
print('\n')
show_example(pytree_custom_layer)

structured=CustomLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')
  flat=[CustomLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')]
  tree=PyTreeDef(*)
  unflattened=CustomLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')


structured=PyTreeCustomLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')
  flat=[Array([[ 0.18784384, -1

In [43]:
leaves, aux_data = pytree_custom_layer.tree_flatten()
print(leaves)
print(aux_data)

layer_reconstructed = PyTreeCustomLayer.tree_unflatten(aux_data, leaves) # will work
print(layer_reconstructed)

(Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32))
('layer1',)
PyTreeCustomLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')


Python objects: https://docs.python.org/3/reference/datamodel.html#basic-customization, https://stackoverflow.com/questions/73409385/object-class-documentation-in-python, https://docs.python.org/3/library/functions.html#object

### 1.2.2. All equinox modules ```eqx.Modules``` are Pytrees

In [46]:
batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))

print(type(model))
print(model)
print(model.weight)
print(model.bias)

<class '__main__.Linear'>
Linear(weight=f32[3,2], bias=f32[3])
[[0.59902626 0.2172144 ]
 [0.660603   0.03266738]
 [1.2164948  1.1940813 ]]
[ 1.1378784  -1.2209548  -0.59153634]


In [52]:
print(type(model))
print(type(model).__bases__)
print(type(model).__bases__[0].__bases__)

<class '__main__.Linear'>
(<class 'equinox._module.Module'>,)
(<class 'object'>,)


In [45]:
leaves, aux_data = jax.tree_util.tree_flatten(model)
print(leaves, '\n')
print(aux_data)

[Array([[0.59902626, 0.2172144 ],
       [0.660603  , 0.03266738],
       [1.2164948 , 1.1940813 ]], dtype=float32), Array([ 1.1378784 , -1.2209548 , -0.59153634], dtype=float32)] 

PyTreeDef(CustomNode(Linear[('weight', 'bias'), (), ()], [*, *]))


In [53]:
tree_mapped_fn = lambda x: x.shape 
print(jax.tree_util.tree_map(tree_mapped_fn, jax.tree_util.tree_flatten(model, is_leaf=lambda x: isinstance(x, jax.Array))[0]))
print(jax.tree_util.tree_map(tree_mapped_fn, model))

[(3, 2), (3,)]
Linear(weight=(3, 2), bias=(3,))


In [None]:
batched_input = jax.random.normal(jax.random.PRNGKey(0), (batch_size, in_size))



TypeError: dot_general requires contracting dimensions to have the same shape, got (2,) and (32,).

In [58]:
try: 
    print(model(batched_input))
except Exception as e:
    print(f'ERROR: {e}')
    print(f'w @ x = {model.weight.shape}, {batched_input.shape}')

ERROR: dot_general requires contracting dimensions to have the same shape, got (2,) and (32,).
w @ x = (3, 2), (32, 2)


In [68]:
out = jax.vmap(model)(batched_input)
print(out.shape)
out_manual = batched_input @ model.weight.T + model.bias
print(out_manual.shape)
print(jnp.allclose(out, out_manual, atol=1e-2))

(32, 3)
(32, 3)
True


In [70]:
import time

start = time.time()
out = jax.vmap(model)(batched_input)
end = time.time()
print(f'vmap time: {(end - start)*1e3:.2f} ms')

start = time.time()
out_manual = batched_input @ model.weight.T + model.bias
end = time.time()
print(f'manual time: {(end - start)*1e3:.2f} ms')

vmap time: 1.76 ms
manual time: 0.50 ms


In [71]:
!which python

  pid, fd = os.forkpty()


/Users/tristantorchet/Desktop/Code/VSCode/QSSM/.venv/bin/python


In [72]:
!pip install jaxtyping

  pid, fd = os.forkpty()




In [None]:
class MLP(eqx.Module):
    layer1: eqx.nn.Linear
    bn1: eqx.nn.BatchNorm
    layer2: eqx.nn.Linear
    bn2: eqx.nn.BatchNorm
    layer3: eqx.nn.Linear

    def __init__(self, in_size, hidden_size, out_size, key):
        self.layer1 = eqx.nn.Linear(in_size, hidden_size, key)
        self.bn1 = eqx.nn.BatchNorm(hidden_size, axis_name='batch')
        self.layer2 = eqx.nn.Linear(hidden_size, hidden_size, key)
        self.bn2 = eqx.nn.BatchNorm(hidden_size, axis_name='batch')
        self.layer3 = eqx.nn.Linear(hidden_size, out_size, key)

    def __call__(self, x, state):
        x = self.layer1(x)
        x = jax.nn.relu(x)
        x = self.layer2(x)
        x = jax.nn.relu(x)
        x = self.layer3(x)
        return x

In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax

In [2]:
# This model is just a weird mish-mash of stateful and non-stateful layers for
# demonstration purposes, it isn't doing any clever.
class Model(eqx.Module):
    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jr.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state

In [3]:
def compute_loss(model, state, xs, ys):
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(xs, state)
    loss = jnp.mean((pred_ys - ys) ** 2)
    return loss, state


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    grads, state = eqx.filter_grad(compute_loss, has_aux=True)(model, state, xs, ys)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state

In [5]:
dataset_size = 10
learning_rate = 3e-4
steps = 5
seed = 5678

key = jr.PRNGKey(seed)
mkey, xkey, xkey2 = jr.split(key, 3)

model, state = eqx.nn.make_with_state(Model)(mkey)
print(state)

xs = jr.normal(xkey, (dataset_size, 3))
ys = jnp.sin(xs) + 1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)
    print(state)

State(
  0x1015c5948=bool[],
  0x1015c5968=(f32[3], f32[3]),
  0x1015c5988=(f32[32], f32[3]),
  0x1015c59a8=bool[],
  0x1015c59c8=(f32[32], f32[32])
)
State(
  0x1015c5948=bool[],
  0x1015c5968=(f32[3], f32[3]),
  0x1015c5988=(f32[32], f32[3]),
  0x1015c59a8=bool[],
  0x1015c59c8=(f32[32], f32[32])
)
State(
  0x1015c5948=bool[],
  0x1015c5968=(f32[3], f32[3]),
  0x1015c5988=(f32[32], f32[3]),
  0x1015c59a8=bool[],
  0x1015c59c8=(f32[32], f32[32])
)
State(
  0x1015c5948=bool[],
  0x1015c5968=(f32[3], f32[3]),
  0x1015c5988=(f32[32], f32[3]),
  0x1015c59a8=bool[],
  0x1015c59c8=(f32[32], f32[32])
)
State(
  0x1015c5948=bool[],
  0x1015c5968=(f32[3], f32[3]),
  0x1015c5988=(f32[32], f32[3]),
  0x1015c59a8=bool[],
  0x1015c59c8=(f32[32], f32[32])
)
State(
  0x1015c5948=bool[],
  0x1015c5968=(f32[3], f32[3]),
  0x1015c5988=(f32[32], f32[3]),
  0x1015c59a8=bool[],
  0x1015c59c8=(f32[32], f32[32])
)
