In [51]:
import jax
from jax import numpy as jnp
import numpy as np
from functools import partial
import equinox as eq
from jax.config import config
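# Optional debug switches (uncomment as needed): raise an error as soon as a NaN
# is produced, or disable JIT so code runs eagerly and can be stepped through.
#config.update("jax_debug_nans", True)
#config.update("jax_disable_jit", True)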

In [53]:
def quadratic_function(x):
    """Evaluate the quadratic form ``x.T @ x``.

    For a 1-D vector this is the sum of squared entries (the squared
    Euclidean norm). It is a convex bowl whose minimum sits at the origin.
    For a 2-D array it is the Gram matrix ``x.T @ x`` rather than a scalar.
    """
    x_transposed = x.T
    return x_transposed @ x

In [54]:
# Sanity check on an integer vector: 1**2 + 2**2 + 3**2 = 14.
quadratic_function(jnp.array([1, 2, 3]))

Array(14, dtype=int32)

In [42]:
# Compile the whole optimisation loop with equinox's filter_jit, which (per the
# equinox docs) traces array arguments and holds non-array ones, such as the
# callable `function` and the Python int `t`, static.
@eq.filter_jit
def gradient_descent(lr, t, function, x, kwargs=None):
    """Run ``t`` steps of plain gradient descent on ``function`` from ``x``.

    Args:
        lr: Step size. May be a traced scalar, so the result can itself be
            differentiated with respect to the learning rate.
        t: Number of steps. Must be a concrete Python int because it is the
            ``length`` of ``jax.lax.scan``.
        function: Scalar-valued objective, called as ``function(x, **kwargs)``.
        x: Starting point (array or float). It must have a floating dtype,
            because ``jax.grad`` rejects integer inputs.
        kwargs: Optional extra keyword arguments bound into ``function``.

    Returns:
        The iterate reached after ``t`` update steps ``x <- x - lr * grad(x)``.
    """
    if kwargs is None:
        kwargs = {}

    # Bind the extra arguments so the objective depends on x alone, then take
    # its gradient. Only the gradient is needed, so plain `jax.grad` suffices.
    grad_function = jax.grad(partial(function, **kwargs))

    # Scan body: the carry is the current iterate. No per-step output is
    # emitted (None), so scan does not stack the trajectory into a (t, ...)
    # array that would only be thrown away.
    def step(x, _):
        return x - lr * grad_function(x), None

    # lax.scan is a compiled loop that can be jitted and differentiated through.
    x, _ = jax.lax.scan(step, x, None, t)
    return x

In [74]:
# The gradient of x.T @ x is 2x, so each step scales x by (1 - 2 * 0.1) = 0.8.
# After 100 steps x is about 0.8**100 * x0 (~2e-10 * x0), i.e. at the origin.
gradient_descent(lr=0.1, t=100, function=quadratic_function, x=jnp.array([1.0, -2.0, 1.0]))

Array([ 2.0370368e-10, -4.0740736e-10,  2.0370368e-10], dtype=float32)

In [69]:
def loss(lr, **kwargs):
    """Meta-objective for tuning the step size ``lr``.

    Runs ``gradient_descent`` with step size ``lr``. All its other arguments
    (e.g. ``t``, ``function``, ``x``) come from ``kwargs``. Returns
    ``x_final.T @ x_final`` for the final iterate ``x_final``, which is its
    squared Euclidean norm when the iterate is 1-D.
    """
    x_final = gradient_descent(lr=lr, **kwargs)
    return jnp.matmul(x_final.T, x_final)

In [79]:
# Meta-learning: tune the step size `lr` itself with (outer) gradient descent.
# `loss` takes `lr` as its only positional argument. The inner run's settings
# (3 steps on quadratic_function starting at (1, -2, 1)) are bound via `args`,
# so the outer loop differentiates the whole 3-step inner run w.r.t. `lr`.
# Analytically the inner iterate is (1 - 2*lr)**3 * x0, so the optimum is
# lr = 0.5. The outer loop approaches it slowly because the meta-loss
# 6 * (1 - 2*lr)**6 is very flat near its minimum.
args = dict(t=3, function=quadratic_function, x=jnp.array([1.0, -2.0, 1.0]))
lr = gradient_descent(lr=0.0005, t=1000, function=loss, x=0.1, kwargs=args)
print(lr)

0.37897322
