In [22]:
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax.scipy.stats import norm
from jax import grad, jit, vmap
from jax import random
import jax

from typing import Sequence, Callable, Union, Tuple, Optional, Any
from jaxtyping import Array, Float, Int, PyTree

import equinox as eqx
import optax
import chex

from dataclasses import dataclass
from functools import partial

jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

In [23]:
def f(x):
    return jnp.sin(x)

fs = vmap(f)

xs = jnp.array([0., jnp.pi / 2, jnp.pi])[..., jnp.newaxis]
ys = fs(xs)

print(xs.shape)
print(ys.shape)

(3, 1)
(3, 1)


In [74]:
primals, f_jvp = jax.jvp(f, (xs,), (jnp.ones_like(xs),))
print(primals)
print(f_jvp)


[[ 0.000000e+00]
 [ 1.000000e+00]
 [-8.742278e-08]]
[[ 1.000000e+00]
 [-4.371139e-08]
 [-1.000000e+00]]


In [24]:
key = random.PRNGKey(0)
key, subkey = random.split(key)

In [52]:

key, subkey = jax.random.split(key)
model = eqx.nn.MLP(key=subkey, in_size=1, out_size=1, width_size=20, depth=3, activation=jax.nn.elu)


In [53]:
@eqx.filter_jit
def value_and_jacrev(f, x):
    # Function that computes both the values and the Jacobian using reverse mode autodiff
    y, vjpfun = jax.vjp(f, x)  # Compute the values and obtain a VJP function
    basis = jnp.eye(y.size, dtype=y.dtype)
    jac = vmap(vjpfun)(basis)[0]  # Compute the Jacobian using the VJP function
    return y, jac, vjpfun

@eqx.filter_jit
def value_and_jacobian(f, x):
    y, jac, _ = value_and_jacrev(f, x)
    return y, jac

In [68]:
print(xs)
val, jac, vjpfun = value_and_jacrev(model, jnp.array([0.]))
print(val)
print(jac)
print("-" * 40)
val, jac, vjpfun = value_and_jacrev(model, xs[1])
print(val)
print(jac)

vmap(partial(value_and_jacobian, model))(xs)

[[0.       ]
 [1.5707964]
 [3.1415927]]
[-0.05633701]
[[-0.0524135]]
----------------------------------------
[-0.11803523]
[[-0.03657701]]


(DeviceArray([[-0.05633698],
              [-0.11803527],
              [-0.17688681]], dtype=float32),
 DeviceArray([[[-0.05241352]],
 
              [[-0.03657701]],
 
              [[-0.03539729]]], dtype=float32))

In [55]:
jax.jvp(vjpfun, (jnp.ones_like(val),), (jnp.ones_like(val),))

((DeviceArray([-0.0352001], dtype=float32),),
 (DeviceArray([-0.0352001], dtype=float32),))

In [42]:
class linear(eqx.Module):
    weight: Array
    bias: Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

In [401]:
import jax.numpy as np
def hessian_vector_product_fn(func):
    """Returns a function that efficiently computes the Hessian-vector product of `func`."""
    def hvp_fn(x, v):
        """Computes the Hessian-vector product of `func`."""
        # Define a closure for the function evaluation
        def f_eval(x):
            return func(x)

        # Compute function evaluation using forward mode differentiation
        f_eval_x, vjp_fun = jax.vjp(f_eval, x)
        # Compute gradient using reverse mode differentiation
        grad = vjp_fun(np.ones_like(f_eval_x))[0]
        # Compute Hessian-vector product using forward mode differentiation
        hvp = jax.grad(lambda x: np.dot(grad, v))(x)
        return f_eval_x, grad, hvp

    return hvp_fn

# Example function to compute Hessian-vector product
def example_function(x):
    return np.sum(np.square(x) * x)

# Define inputs and initial vector for Hessian-vector product
x = np.array([1.0, 2.0, 3.0])
v = np.array([0.1, 0.2, 0.3])

# Compute function evaluation, gradient, and Hessian-vector product using JAX
hvp_fn = hessian_vector_product_fn(example_function)
f_eval, grad, hvp = hvp_fn(x, v)

print("Function evaluation:", f_eval)
print("Gradient:", grad)
print("Hessian-vector product:", hvp)

Function evaluation: 36.0
Gradient: [ 3. 12. 27.]
Hessian-vector product: [0. 0. 0.]


In [388]:
def hvp(f, primals, tangents):
    pri, vjp_fun = jax.vjp(f, primals)
    return jax.jvp(vjp_fun, primals, tangents)

In [6]:
def f(X):
  return jnp.sum(jnp.tanh(X)**2)


y, vjp_fun = jax.vjp(f, jnp.array([1., 2., 3.]))
print(y)
print(vjp_fun(jnp.array([1., 1., 1.])))

2.4995086


ValueError: Shape of cotangent input to vjp pullback function (3,) must be the same as the shape of corresponding primal input ().

In [389]:
def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(jax.hessian(f)(X), V, 2)

print(ans1)
print("-" * 80)
print(ans2)
# print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))

TypeError: tanh requires ndarray or scalar arguments, got <class 'tuple'> at position 0.

In [278]:

# def f(x):
#   return jnp.asarray(
#     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
def f(X):
    # return jnp.asarray([jnp.sin(x)])
    return (jnp.sum(jnp.sin(X) ** 2),)



In [267]:
def jacrev_and_vjp(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
           has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
  """

  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "jacobian of the output with respect to the arguments at "
            "positions {argnums}.")

  def jacfun(*args, **kwargs):
    f = jax.scipy.linalg.lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
    if not has_aux:
      y, pullback = _vjp(f_partial, *dyn_args)
    else:
      y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
    jac = vmap(pullback)(_std_basis(y))
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
    jac_tree = tree_transpose(tree_structure(example_args), tree_structure(y), jac_tree)
    if not has_aux:
      return jac_tree, _vjp
    else:
      return jac_tree, _vjp, aux

  return jacfun

In [17]:
from jax import jvp, vjp, value_and_grad

alpha = jnp.ones(2) # some jnp.ndarray
test = jnp.ones(2) * 2 

# random data
Q = jnp.array([[1., 0.5], [0.5, 1.]])
b = 3. * jnp.ones(2)
c = 5.

# logic removed for now, but computes a scalar function on Rn -> R
def func(alpha: jnp.ndarray):
    return 0.5 * alpha.T @ Q @ alpha + b.T @ alpha + c 

def _hvp_1(g_f, primals, tangents):
    return jvp(g_f, (primals,), (tangents,))[1]

def _hvp_2(f, primals, tangents):
    return jvp(grad(f), (primals,), (tangents,))[1]

# this does not compute what I'd like
val_1, grad_f_ = jax.vjp(func, alpha)
grad_f_1 = lambda x: grad_f_(x)[0]
grad_val_1 = grad_f_1(jnp.ones(()))
hess_vp_1 = partial(_hvp_1, grad_f_1, alpha)
print(f"grad value 1 = {grad_val_1}")
print(f"not hessvp = {hess_vp_1(test)}")

# grad value 1 = [4.5 4.5]
# not hessvp = [[9. 9.]
#  [9. 9.]]

# this returns what I need, but does extra amount of work of calling `value_and_grad` twice
# once for this call here, and another in the `grad` call inside `_hvp_2`
val_2, grad_val_2 = value_and_grad(func)(alpha)
hess_vp_2 = partial(_hvp_2, func, alpha)
print(f"grad value 2 = {grad_val_2}")
print(f"hessvp = {hess_vp_2(test)}")

# grad value 2 = [4.5 4.5]
# hessvp = [3. 3.]

grad value 1 = [4.5 4.5]


ValueError: Shape of cotangent input to vjp pullback function (2,) must be the same as the shape of corresponding primal input ().

In [16]:
key = random.PRNGKey(0)
key, subkey = random.split(key, 2)
model = eqx.nn.MLP(key=subkey, in_size=3, out_size=1, width_size=20, depth=3, activation=jax.nn.elu)

# jax.grad(model)(jnp.ones((3,)) * jnp.pi / 2)

X = jnp.ones((3,)) * jnp.pi / 2
V = jnp.ones_like(X)
# primals, vjp_fun = jax.vjp(model, X)
# J, = vmap(vjp_fun)(jnp.eye(len(primals)))

# jax.jvp(vjp_fun, (X,), (V,))

hess_vp_1 = jax.jvp(lambda x: jax.vjp(f, X)[1](x)[0],
                (X,), (V,))[1]

ValueError: Shape of cotangent input to vjp pullback function (3,) must be the same as the shape of corresponding primal input ().

In [377]:
key, subkey = random.split(key, 2)
model = eqx.nn.MLP(key=subkey, in_size=3, out_size=1, width_size=20, depth=3, activation=jax.nn.elu)

# X = random.normal(subkey1, (3,4))
# V = random.normal(subkey2, (3,4))
# V = jnp.ones((3,4))

# res = hvp(f, (X,), (V,))

# X = random.normal(subkey1, (3,))
X = jnp.ones((3,)) * jnp.pi / 2
V = jnp.ones_like(X)

# X = X[..., jnp.newaxis]

primals, vjp_fun = jax.vjp(model, X)
J, = vmap(vjp_fun)(jnp.eye(len(primals)))

# hesse_vector
# a, b = jax.jvp(vjp_fun, (J,), (jnp.eye(len(primals)),))

# a = jax.jacfwd(vjp_fun)(jnp.ones((1,)))

a = hvp(model, (jnp.ones(1,),), (jnp.ones(1,),))

# a, b = jax.jvp(vjp_fun, (primals,), (jnp.array([3.0]),))

print("-" * 80)
print(primals.shape)
print(primals)
print(J)
print(a)
print(b)
print("-" * 80)

print(X.shape)
print(res[0].shape)
print(res[1].shape)
# print(res.shape)
print(res)

TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (1,).

In [225]:
jax.jvp(jnp.sin, (jnp.pi / 2, ), (1.0, ))

(DeviceArray(1., dtype=float32, weak_type=True),
 DeviceArray(-4.371139e-08, dtype=float32, weak_type=True))

In [161]:
class MakeScalar(eqx.Module):
    model: eqx.Module

    def __call__(self, *args, **kwargs):
        out = self.model(*args, **kwargs)
        return jnp.reshape(out, ())

model = eqx.nn.MLP(key=subkey, in_size=1, out_size=1, width_size=20, depth=3, activation=jax.nn.elu)
print(model(jnp.array([1.0])))


xs_train = jnp.array([[1.0], [2.0], [3.0]])
print(vmap(model)(xs_train))

print("test")
g_model = grad(MakeScalar(model))

print(g_model(jnp.array([1.0])))
print(vmap(g_model)(xs_train))


test = jax.jacrev(vmap(model))(jnp.array([[1.0], [2.0], [3.0]]))
print(test)

X = xs_train
Y = jnp.ones_like(xs_train)
res = hvp(f, (X, ), (Y, ))
print(res)
# vmap(grad(model))(jnp.array([1.0]))

# primals, f_vjp = eqx.filter_vjp(model, jnp.array([1.0]))

# print(primals)
# jax.jacrev(f)(jnp.array([1.0, 2.0, 3.0]))

# res = jax.jacobian(vmap(model))(jnp.array([0.0, jnp.pi/2, jnp.pi, 2*jnp.pi]))
# res = vmap(jax.jacobian(model))(jnp.array([0.0, jnp.pi/2]))

# print(res.shape)
# print(res)


# model = eqx.nn.MLP(key=subkey, in_size=1, out_size=1, width_size=20, depth=3, activation=jax.nn.elu)

# print(model(jnp.array([1.0])).shape)
# x_test = jnp.array([[1.0], [2.0]])

# vmap(model)(x_test)

[0.20363042]
[[0.20363042]
 [0.20618227]
 [0.20519818]]
test
[0.01606509]
[[ 0.01606509]
 [-0.00216613]
 [ 0.00098517]]
[[[[ 0.01606509]
   [ 0.        ]
   [ 0.        ]]]


 [[[ 0.        ]
   [-0.00216612]
   [ 0.        ]]]


 [[[ 0.        ]
   [ 0.        ]
   [ 0.00098517]]]]


TypeError: Argument '<function f at 0x0000020EB29E5550>' of type <class 'function'> is not a valid JAX type.

In [25]:
def loss_fn(model: eqx.nn.MLP, x: Float[Array, "batch"], y: Float[Array, "batch"]) -> Float[Array, ""]:
    pred_y = vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

In [26]:
vmap(jax.jacrev(model))(xs)

DeviceArray([[[ 0.00450898]],

             [[-0.00172892]],

             [[ 0.00161424]]], dtype=float32)

In [58]:
jax.hessian(f)(jnp.pi)

DeviceArray(8.742278e-08, dtype=float32, weak_type=True)

In [37]:
jnp.sin(jnp.pi)

DeviceArray(-8.742278e-08, dtype=float32, weak_type=True)