In [None]:
import jax 
import jax.numpy as jnp
import time 


# `jax.jit` traces `self.<attribute>` into compile-time constants, so we can't jit a class method directly <font color='red'>if we want to modify that attribute between calls</font>
### Instead, we need to jit a wrapper function that calls the class method (the class method itself can still be vmapped)

### 1. Dummy counter class

In [120]:

class Counter:
    """Minimal stateful counter whose whole state is the Python attribute ``n``."""

    def __init__(self):
        # Start from zero by delegating to the public reset method.
        self.reset()

    def count(self) -> int:
        """Advance the counter by one and return the updated value."""
        self.n = self.n + 1
        return self.n

    def reset(self):
        """Put the counter back to zero."""
        self.n = 0

counter = Counter()

# Plain (un-jitted) Python: every call mutates `counter.n` and returns the new value.
for i in range(3):
    print(n := counter.count())

1
2
3


### 2. First attempt to speed up the code: using ```jit``` on the class method --> <font color='red'> doesn't work, BUT DOESN'T RAISE AN ERROR</font>

In [121]:
counter.reset()
# jax.jit traces `counter.count` on the first call. At that point `self.n` is a plain
# Python int, so `self.n += 1` runs once during tracing (leaving counter.n == 1) and the
# compiled function just returns the constant 1 on every call: silently wrong, no error.
fast_count = jax.jit(counter.count)
for i in range(3):
    result = fast_count()
    print(f'out = {result}')

out = 1
out = 1
out = 1


### 3. Second attempt to speed up the code: using ```jit``` on a wrapper function that calls the class method --> <font color='green'> works: around 13x faster in the recorded run</font> (caveat: JAX dispatches work asynchronously, so a fair benchmark must call `block_until_ready()` on the result before stopping the timer)
##### We create a wrapper function that takes the object's state (here `n`) as arguments, instantiates a fresh object inside the wrapper, and calls the instance method there.

In [129]:
from functools import partial
@partial(jax.jit, static_argnums=(1, 2))
def wrapper(params, n_iter=10000, log=False):
    """Run ``Counter.count`` ``n_iter`` times on a fresh counter seeded from ``params``.

    The counter state comes in as a (traced) argument instead of being read from a live
    object, so jit does not freeze it into a compile-time constant.

    Args:
        params: tuple whose first element is the initial count.
        n_iter: number of increments. Static, so the Python loop below is unrolled at
            trace time (one add per iteration in the traced program).
        log: static flag; when True each intermediate value is printed with
            ``jax.debug.print``.

    Returns:
        A one-element tuple holding the final count.
    """
    local_counter = Counter()
    local_counter.n = params[0]
    for _ in range(n_iter):
        current = local_counter.count()
        if log:
            jax.debug.print('out = {n}', n=current)
    return (local_counter.n,)

counter.reset()
params = (counter.n,)

# Compile + run with logging to show the first three values.
_ = wrapper(params, n_iter=3, log=True)

# Warm-up: (n_iter, log) are static, so this combination triggers its own compilation,
# which must stay out of the timed region. Block so the compile/run has really finished.
wrapper(params, n_iter=10000, log=False)[0].block_until_ready()

# JAX dispatches asynchronously: without block_until_ready() we would only time the
# dispatch, not the computation.
start = time.perf_counter()
params_updated = wrapper(params, n_iter=10000, log=False)
params_updated[0].block_until_ready()
end = time.perf_counter()
print(f'time = {(end - start)*1000:.2f} ms')
print(f'params = {params_updated}')

out = 1
out = 2
out = 3
time = 0.07 ms
params = (Array(10000, dtype=int32, weak_type=True),)


In [130]:
# Baseline: the same 10000 increments in plain Python. Reset first so the cell gives the
# same result when re-run (otherwise `counter` keeps its previous count).
counter.reset()
start = time.perf_counter()
for i in range(10000):
    n = counter.count()
end = time.perf_counter()
print(f'time = {(end - start)*1000:.2f} ms')
print(f'out = {n}')

time = 0.91 ms
out = 10000


### 4. Application: Linear layer as a class to mimic Equinox

#### 4.1. Dummy Linear layer class with ```jit``` on the class method --> <font color='red'> doesn't work BUT DOESN'T RAISE AN ERROR</font>

In [6]:
class Linear:
    """Plain-Python linear layer ``weight @ x + bias`` whose parameters are attributes.

    Used to show that jitting the bound ``__call__`` captures the attributes as constants.
    """

    def __init__(self, in_features, out_features):
        init_key = jax.random.PRNGKey(0)
        # NOTE(review): weight and bias are drawn from the same key. JAX normally recommends
        # splitting keys, but the parameters (and the parity with EqxLinear's outputs)
        # depend on this exact scheme, so it is kept as is.
        weight_shape = (out_features, in_features)
        self.weight = jax.random.normal(init_key, weight_shape)
        self.bias = jax.random.normal(init_key, (out_features,))

    def __call__(self, x):
        """Apply the affine map to one input vector (callers vmap over the batch)."""
        # Under jit this prints, on each call, the weight values that were captured at trace time.
        jax.debug.print('jitted self.weight = \n{weight}', weight=self.weight)
        projected = jnp.dot(self.weight, x)
        return projected + self.bias

In [134]:
dim_in, dim_out, bs = 3, 4, 10

# Split the root key and draw the input batch (shape (bs, dim_in)) from the second half.
key, subkey = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(subkey, (bs, dim_in))
layer = Linear(dim_in, dim_out)

In [135]:
# jit closes over `layer`, so weight/bias are baked in as constants at trace time;
# vmap then maps the jitted per-sample call over axis 0 of the batch.
layer_call = jax.vmap(jax.jit(layer), in_axes=(0,))
print(f'out = \n{layer_call(x)}\n')

# Scale the weights and bias DOWN by a factor of 2 (re-running this cell halves them again).
halved_weight = layer.weight / 2
halved_bias = layer.bias / 2
layer.weight = halved_weight
layer.bias = halved_bias
print(f'layer.weight = \n{layer.weight}')
print(f'out = \n{layer_call(x)}') # WARNING: we can see that the output is not scaled, meaning that it still uses the old weights and biases

jitted self.weight = 
[[ 1.1901639  -1.0996888   0.44367844]
 [ 0.5984697  -0.39189556  0.69261974]
 [ 0.46018356 -2.068578   -0.21438177]
 [-0.9898306  -0.6789304   0.27362573]]
out = 
[[ 3.804824    0.33519322  0.44742838 -2.7707675 ]
 [ 3.625029    0.02590603  3.410178    0.50355005]
 [ 0.8323103  -1.398576    0.28799635  0.3267359 ]
 [ 1.7365305  -0.21044779 -1.2441818  -0.9568258 ]
 [ 3.3418374   0.42115265  0.2816336  -1.6189749 ]
 [ 3.0669456  -0.31462416  0.10747854 -2.8793561 ]
 [ 1.9252988  -0.1918292  -1.0745429  -1.1775186 ]
 [ 1.645749   -0.4092752  -2.20948    -2.3588734 ]
 [ 2.2478125  -0.71958554  2.4841554   1.0517647 ]
 [ 3.4196887   0.1534388   0.6268668  -2.0282063 ]]

layer.weight = 
[[ 0.5950819  -0.5498444   0.22183922]
 [ 0.29923484 -0.19594778  0.34630987]
 [ 0.23009178 -1.034289   -0.10719088]
 [-0.4949153  -0.3394652   0.13681287]]
jitted self.weight = 
[[ 1.1901639  -1.0996888   0.44367844]
 [ 0.5984697  -0.39189556  0.69261974]
 [ 0.46018356 -2.068578   -0.

#### 4.2 Taking inspiration from Equinox

##### 4.2.1: Equinox Custom Linear

In [None]:
import equinox as eqx
class EqxLinear(eqx.Module):
    """Equinox counterpart of ``Linear``: ``weight`` and ``bias`` are declared Module fields."""

    weight: jax.Array  # (out_features, in_features), as created in __init__
    bias: jax.Array    # (out_features,), as created in __init__

    def __init__(self, in_features, out_features):
        init_key = jax.random.PRNGKey(0)
        # NOTE(review): same shared-key initialisation as `Linear`, so both layers start
        # with identical parameters (JAX usually recommends splitting keys).
        self.weight = jax.random.normal(init_key, (out_features, in_features))
        self.bias = jax.random.normal(init_key, (out_features,))

    def __call__(self, x):
        """Affine map ``weight @ x + bias`` for one input vector."""
        projected = jnp.dot(self.weight, x)
        return projected + self.bias

out=[ 0.00231779 -1.9878993  -1.513092   -1.1650387 ]


In [96]:
# Re-declare the problem sizes and data so this section can be run on its own.
dim_in = 3
dim_out = 4
bs = 10

# Derive `x` exactly like the setup cell of section 4.1 (second key of split(PRNGKey(0))):
# the recorded outputs in this section were produced with that `x`, so drawing it from
# PRNGKey(0) directly would make Restart & Run All disagree with them.
data_key = jax.random.split(jax.random.PRNGKey(0))[1]
x = jax.random.normal(data_key, (bs, dim_in))
y = jnp.ones((bs,), dtype=jnp.int32)  # every sample's target class is 1

In [144]:
eqx_layer = EqxLinear(dim_in, dim_out)

# One sample first, then the whole batch via vmap over axis 0.
first_out = eqx_layer(x[0])
print(f'out={first_out}')
batched_eqx = jax.vmap(eqx_layer, in_axes=(0,))
print(f'out = \n{batched_eqx(x)}\n')

out=[ 3.804824    0.33519322  0.44742838 -2.7707675 ]
out = 
[[ 3.804824    0.33519322  0.44742838 -2.7707675 ]
 [ 3.625029    0.02590603  3.410178    0.50355005]
 [ 0.8323103  -1.398576    0.28799635  0.3267359 ]
 [ 1.7365305  -0.21044779 -1.2441818  -0.9568258 ]
 [ 3.3418374   0.42115265  0.2816336  -1.6189749 ]
 [ 3.0669456  -0.31462416  0.10747854 -2.8793561 ]
 [ 1.9252988  -0.1918292  -1.0745429  -1.1775186 ]
 [ 1.645749   -0.4092752  -2.20948    -2.3588734 ]
 [ 2.2478125  -0.71958554  2.4841554   1.0517647 ]
 [ 3.4196887   0.1534388   0.6268668  -2.0282063 ]]



##### 4.2.2: Custom class (registered as a PyTree) using the parameters of the Equinox Custom Linear

In [138]:
from dataclasses import dataclass
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class

@register_pytree_node_class
@dataclass
class PyTreeLinear:
    """Linear layer ``weight @ x + bias`` registered as a JAX PyTree.

    ``tree_flatten`` exposes both arrays as leaves (no static aux data), so instances can
    be flattened into their arrays, rebuilt with ``tree_unflatten`` and mapped with
    ``jax.tree_util.tree_map``.
    """

    weight: jnp.ndarray  # (out_features, in_features) as built in this notebook
    bias: jnp.ndarray    # (out_features,) as built in this notebook

    def __call__(self, x):
        """Affine map applied to one input vector."""
        projected = jnp.dot(self.weight, x)
        return projected + self.bias

    def tree_flatten(self):
        """Return ``(leaves, aux)``: the two arrays as leaves and ``None`` as aux data."""
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from ``tree_flatten`` output; ``aux`` is ignored."""
        return cls(*children)
    

In [143]:
# Build the PyTree layer from the Equinox layer's arrays so both compute the same map.
layer = PyTreeLinear(weight=eqx_layer.weight, bias=eqx_layer.bias)
single_out = layer(x[0])
print(f'out = {single_out}\n')
print(f'out = \n{jax.vmap(layer, in_axes=(0,))(x)}\n')

out = [ 3.804824    0.33519322  0.44742838 -2.7707675 ]

out = 
[[ 3.804824    0.33519322  0.44742838 -2.7707675 ]
 [ 3.625029    0.02590603  3.410178    0.50355005]
 [ 0.8323103  -1.398576    0.28799635  0.3267359 ]
 [ 1.7365305  -0.21044779 -1.2441818  -0.9568258 ]
 [ 3.3418374   0.42115265  0.2816336  -1.6189749 ]
 [ 3.0669456  -0.31462416  0.10747854 -2.8793561 ]
 [ 1.9252988  -0.1918292  -1.0745429  -1.1775186 ]
 [ 1.645749   -0.4092752  -2.20948    -2.3588734 ]
 [ 2.2478125  -0.71958554  2.4841554   1.0517647 ]
 [ 3.4196887   0.1534388   0.6268668  -2.0282063 ]]



##### 4.2.3: Applying a transformation to the parameters of the custom class between calls to the wrapper function
##### bs = 1

In [146]:
@jax.jit
def loss_bs_1(params, aux, x, y):
    """Cross-entropy of a PyTreeLinear model on a single (x, y) sample, with debug prints.

    Args:
        params: leaves from ``PyTreeLinear.tree_flatten`` (weight, bias).
        aux: aux data from ``tree_flatten``.
        x: one input vector.
        y: the true class index of that sample.
    """
    model_ = PyTreeLinear.tree_unflatten(aux, params)
    logits = model_(x)
    jax.debug.print('pred_y = {pred_y}', pred_y=logits)
    log_probs = jax.nn.log_softmax(logits)
    jax.debug.print('log_softmax_pred_y = {pred_y}', pred_y=log_probs)
    jax.debug.print('y = {y}', y=y)
    loss_ = cross_entropy_bs_1(y, log_probs)
    jax.debug.print('loss = {loss}', loss=loss_)
    return loss_

def cross_entropy_bs_1(y, pred_y):
    """Negative log-likelihood of the true class for a single sample.

    Args:
        y: integer index of the true class, either a scalar or an array of shape (1,).
        pred_y: log-softmax'd prediction of shape (out_dim,).

    Returns:
        Scalar loss equal to ``-pred_y[y]``.
    """
    # take_along_axis needs an index array with the same rank as pred_y (rank 1), so
    # normalise a scalar or shape-(1,) target to shape (1,). Wrapping as jnp.array([y])
    # would turn a (1,) target into (1, 1) and fail.
    idx = jnp.reshape(jnp.asarray(y), (1,))
    picked = jnp.take_along_axis(pred_y, idx, axis=0)
    return -jnp.mean(picked)

params, aux = layer.tree_flatten()

# value_and_grad differentiates w.r.t. the first argument, i.e. the (weight, bias) leaves.
loss_, grads_ = jax.value_and_grad(loss_bs_1)(params, aux, x[0], y[0])
print(f'grads = \n{grads_}')

print('\nSCALE DOWN\n')


def scale_fn(leaf):
    """Divide a leaf by -2: halves its magnitude AND flips its sign."""
    return leaf / (-2)


# `jax.tree_map` is deprecated (removed in recent JAX releases); use jax.tree_util.tree_map.
layer_scaled_dn = jax.tree_util.tree_map(scale_fn, layer)
params_dn, aux_dn = layer_scaled_dn.tree_flatten()
loss_, grads_ = jax.value_and_grad(loss_bs_1)(params_dn, aux_dn, x[0], y[0])
print(f'grads = \n{grads_}')

pred_y = [ 3.804824    0.33519322  0.44742838 -2.7707675 ]
log_softmax_pred_y = [-0.06517741 -3.5348084  -3.422573   -6.640769  ]
y = 1
loss = 3.534808397293091
grads = 
(Array([[ 1.8816755e+00,  3.6422554e-01,  5.4732017e-02],
       [-1.9498295e+00, -3.7741771e-01, -5.6714401e-02],
       [ 6.5530933e-02,  1.2684460e-02,  1.9060884e-03],
       [ 2.6230200e-03,  5.0772348e-04,  7.6295393e-05]], dtype=float32), Array([ 0.9369012 , -0.9708356 ,  0.03262837,  0.00130602], dtype=float32))

SCALE DOWN

pred_y = [-1.902412   -0.16759661 -0.22371419  1.3853837 ]
log_softmax_pred_y = [-3.6586835  -1.9238681  -1.9799857  -0.37088773]
y = 1
loss = 1.9238680601119995
grads = 
(Array([[ 5.1749345e-02,  1.0016835e-02,  1.5052255e-03],
       [-1.7150941e+00, -3.3198130e-01, -4.9886689e-02],
       [ 2.7730268e-01,  5.3675946e-02,  8.0658616e-03],
       [ 1.3860421e+00,  2.6828852e-01,  4.0315602e-02]], dtype=float32), Array([ 0.02576641, -0.853959  ,  0.13807121,  0.6901214 ], dtype=float32))


##### bs > 1

In [147]:
@jax.jit
def loss(params, aux, x, y):
    """Mean cross-entropy of a PyTreeLinear model over a batch.

    Args:
        params: leaves from ``PyTreeLinear.tree_flatten`` (weight, bias).
        aux: aux data from ``tree_flatten``.
        x: batch of inputs; vmapped over axis 0.
        y: integer targets, one per sample.
    """
    model_ = PyTreeLinear.tree_unflatten(aux, params)
    logits = jax.vmap(model_, in_axes=(0,))(x)
    log_probs = jax.nn.log_softmax(logits)
    return cross_entropy(y, log_probs)

def cross_entropy(y, pred_y):
    """Mean negative log-likelihood of the true classes over a batch.

    Args:
        y: integer targets, shape (bs,).
        pred_y: log-softmax'd predictions, shape (bs, out_dim).

    Returns:
        Scalar ``-mean_i pred_y[i, y[i]]``.
    """
    # take_along_axis needs indices of the same rank as pred_y, so y becomes (bs, 1).
    target_idx = jnp.expand_dims(y, axis=1)
    picked = jnp.take_along_axis(pred_y, target_idx, axis=1)
    return -picked.mean()

layer = PyTreeLinear(eqx_layer.weight, eqx_layer.bias)

# value_and_grad differentiates w.r.t. the first argument, i.e. the (weight, bias) leaves.
params, aux = layer.tree_flatten()
loss_, grads_ = jax.value_and_grad(loss)(params, aux, x, y)
print(f'loss = {loss_}')
print(f'grads = \n{grads_}')

print('\nSCALE DOWN\n')


def scale_fn(leaf):
    """Divide a leaf by -2: halves its magnitude AND flips its sign."""
    return leaf / (-2)


# `jax.tree_map` is deprecated (removed in recent JAX releases); use jax.tree_util.tree_map.
layer_scaled_dn = jax.tree_util.tree_map(scale_fn, layer)
params_dn, aux_dn = layer_scaled_dn.tree_flatten()
loss_, grads_ = jax.value_and_grad(loss)(params_dn, aux_dn, x, y)
print(f'grads = \n{grads_}')

loss = 3.1309831142425537
grads = 
(Array([[ 0.66311485,  0.22464275,  0.22545396],
       [-0.61407083, -0.10023337, -0.24351758],
       [-0.02413169, -0.11227638,  0.01469187],
       [-0.02491243, -0.01213299,  0.00337176]], dtype=float32), Array([ 0.7487441 , -0.945336  ,  0.1443587 ,  0.05223323], dtype=float32))

SCALE DOWN

grads = 
(Array([[ 0.01075298, -0.00505622,  0.01786483],
       [-0.5837781 , -0.19591357, -0.21283932],
       [ 0.13681674,  0.09540007,  0.09139632],
       [ 0.4362084 ,  0.10556973,  0.10357817]], dtype=float32), Array([ 0.07103135, -0.72101754,  0.21408351,  0.43590268], dtype=float32))


##### Verifying the correctness of the gradients with Equinox

In [149]:
@eqx.filter_jit
@eqx.filter_value_and_grad
def loss(model, x, y):
    """Reference batch cross-entropy computed with Equinox.

    With the decorators, ``loss(model, x, y)`` returns ``(loss_value, grads)``, where the
    grads are taken w.r.t. ``model`` and expose ``.weight`` / ``.bias``.
    NOTE: this rebinds the name ``loss`` previously used for the PyTreeLinear version.
    """
    log_probs = jax.nn.log_softmax(jax.vmap(model)(x))
    return cross_entropy(y, log_probs)

loss_, grads_ = loss(eqx_layer, x, y)
print(f'\nloss = {loss_}')
print(f'grads.weight = \n{grads_.weight}')
print(f'grads.bias = \n{grads_.bias}')


loss = 3.1309831142425537
grads.weight = 
[[ 0.66311485  0.22464275  0.22545396]
 [-0.61407083 -0.10023337 -0.24351758]
 [-0.02413169 -0.11227638  0.01469187]
 [-0.02491243 -0.01213299  0.00337176]]
grads.bias = 
[ 0.7487441  -0.945336    0.1443587   0.05223323]


module: https://github.com/patrick-kidger/equinox/blob/d7d2cb91dde3beee970d9f8f10fc3c9c7f2f0e39/equinox/_module.py#L969
