# Notebook with Optax, Equinox, and JAX

In [2]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import functools
import jaxopt

In [3]:
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  return jnp.dot(params, x)

@jax.jit
def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = jnp.mean(optax.l2_loss(y_pred, y))
  return loss

In [4]:
key = jax.random.PRNGKey(42)
target_params = 0.5

# Generate some data.
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [5]:
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)

# Initialize parameters of the model + optimizer.
init_params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(init_params)

In [6]:
# A simple update loop.
params = init_params
for i in range(3):
  loss_value, grads = jax.value_and_grad(compute_loss)(params, xs, ys)
  print(f"Loss: {loss_value}")
  if loss_value < 1e-15:
    break
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

Loss: 0.3705595135688782
Loss: 0.237158864736557
Loss: 0.13446101546287537


In [7]:
params

Array([0.29512683, 0.2951268 ], dtype=float32)

## Introduce Equinox and a Model class

In [8]:
class Model(eqx.Module):
    xs: jnp.ndarray
    ys: jnp.ndarray

    def __init__(self, seed: int=42, target_params: float=0.5):
        key = jax.random.PRNGKey(seed)
        target_params = 0.5

        # Generate some data.
        self.xs = jax.random.normal(key, (16, 2))
        self.ys = jnp.sum(xs * target_params, axis=-1)

    def model_network(self, params):
        return network(params, self.xs)

In [9]:
@jax.jit
def loss_function(params, model):
  y_pred = model.model_network(params)
  y_target = model.ys
  loss = jnp.mean(optax.l2_loss(y_pred, y_target))
  return loss

In [10]:
model = Model()

In [11]:
model.ys.shape

(16,)

## JAXOPT LBFGSB

In [14]:
params = init_params
solver_2 = jaxopt.LBFGS(fun=loss_function, maxiter = 100, verbose=True)
solver_2.run(init_params=init_params, model=model)

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:1.0  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.LBFGS: Iter: 1 Gradient Norm (stop. crit.): 0.7238786816596985 Objective Value:0.14937657117843628  Stepsize:1.0  Number Linesearch Iterations:1 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:1.0  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.LBFGS: Iter: 2 Gradient Norm (stop. crit.): 0.052312254905700684 Objective Value:0.0011559441918507218  Stepsize:1.0  Number Linesearch Iterations:1 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:1.0  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.LBFGS: Iter: 3 Gradient Norm (stop. crit.): 0.014985358342528343 Objective Value:0.00010153277253266424  Stepsize:1.0  Number Linesearch Iterations:1 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors 

OptStep(params=Array([0.4999901 , 0.49999213], dtype=float32), state=LbfgsState(iter_num=Array(4, dtype=int32, weak_type=True), value=Array(1.2456525e-10, dtype=float32), grad=Array([-1.7826915e-05, -9.2044775e-06], dtype=float32), stepsize=Array(1., dtype=float32), error=Array(2.0062933e-05, dtype=float32), s_history=Array([[ 0.9058708 ,  0.5763672 ],
       [-0.38845462, -0.11766708],
       [-0.01448965,  0.02802739],
       [-0.00293642,  0.01326463],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ]], dtype=float32), y_history=Array([[ 1.6217207 ,  0.68388146],
       [-0.6877681 , -0.15165024],
       [-0.02373748,  0.02979416],
       [-0.00436213,  0.01433262],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0

In [11]:
# A simple update loop.
params = init_params
for i in range(1000):
  loss_value, grads = jax.value_and_grad(loss_function)(params, model)
  print(f"Loss {i}: {loss_value}")
  if loss_value < 1e-15:
    break
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

Loss 0: 0.3705595135688782
Loss 1: 0.12533925473690033
Loss 2: 0.0013298687990754843
Loss 3: 0.06173867732286453
Loss 4: 0.17377790808677673
Loss 5: 0.19612154364585876
Loss 6: 0.1262005716562271
Loss 7: 0.03880932182073593
Loss 8: 0.00013180097448639572
Loss 9: 0.02495267614722252
Loss 10: 0.07474566251039505
Loss 11: 0.09856395423412323
Loss 12: 0.07876671850681305
Loss 13: 0.03622499853372574
Loss 14: 0.004622247070074081
Loss 15: 0.0033412498887628317
Loss 16: 0.025381486862897873
Loss 17: 0.04659378528594971
Loss 18: 0.04799724742770195
Loss 19: 0.030167732387781143
Loss 20: 0.00887465849518776
Loss 21: 8.324929012815119e-07
Loss 22: 0.0071021090261638165
Loss 23: 0.02030024118721485
Loss 24: 0.0263848677277565
Loss 25: 0.020409852266311646
Loss 26: 0.008420398458838463
Loss 27: 0.0005593497771769762
Loss 28: 0.0019398077856749296
Loss 29: 0.00903375819325447
Loss 30: 0.013975651003420353
Loss 31: 0.012135028839111328
Loss 32: 0.005696813575923443
Loss 33: 0.0006245278054848313
Lo