In [1]:
import jax.numpy as jnp
from jax import grad, jacrev, jacfwd, jvp, vjp

In [2]:
def f(x):
    """Toy map from R^4 to R^3 whose three outputs are monomials in ``x``.

    Output ``i`` is ``x[0]**e[i][0] * x[1]**e[i][1] * x[2]**e[i][2] * x[3]**e[i][3]``
    with the exponent table ``e`` below; factors are multiplied left to right.

    Args:
        x: indexable with entries x[0]..x[3] (a length-4 ``jnp`` array in
            this notebook).

    Returns:
        A ``jnp.array`` holding the three monomial values.
    """
    # One row per output component; column j is the power applied to x[j].
    # Exponents are plain Python ints, exactly like the literal exponents
    # they replace.
    exponent_rows = (
        (6, 4, 9, 2),
        (2, 3, 5, 3),
        (5, 7, 7, 6),
    )
    components = []
    for powers in exponent_rows:
        term = x[0] ** powers[0]
        for j in range(1, len(powers)):
            term = term * x[j] ** powers[j]
        components.append(term)
    return jnp.array(components)

In [3]:
# Input point x (4 entries, one per argument of f) at which f and its
# derivatives are evaluated throughout this notebook.
evaluation_point = jnp.array([1., .5, 1.5, 2.])

In [4]:
# Value of f at the evaluation point: one entry per monomial in f.
f(evaluation_point)

Array([9.61084 , 7.59375 , 8.542969], dtype=float32)

In [5]:
# Full Jacobian of f at the evaluation point, computed with forward-mode AD.
# Rows index the outputs of f, columns index the entries of x (3 x 4 here).
full_jac = jacfwd(f)(evaluation_point)

In [6]:
# Display the forward-mode Jacobian.
full_jac

Array([[ 57.66504 ,  76.88672 ,  57.66504 ,   9.61084 ],
       [ 15.1875  ,  45.5625  ,  25.3125  ,  11.390625],
       [ 42.714844, 119.60156 ,  39.867188,  25.628906]], dtype=float32)

In [7]:
# Vector v in input space (4 entries), used as the tangent for the JVP below.
# Its first three entries are later reused as the cotangent (output space) for the VJP.
multiplication_point = jnp.array([0.2, 0.3, 0.4, 0.8])

In [8]:
# Reference Jacobian-vector product J @ v, formed from the materialized Jacobian.
full_jac @ multiplication_point

Array([65.353714, 35.94375 , 80.87344 ], dtype=float32)

In [9]:
# Forward mode: jvp returns f(x) and J(x) @ v without building the full Jacobian.
# Primals and tangents are passed as tuples, one entry per argument of f.
primals_out, tangents_out = jvp(f, (evaluation_point,), (multiplication_point,))

In [10]:
# Should match f(evaluation_point) and the explicit J @ v above, up to
# float32 rounding in the last digit.
primals_out, tangents_out

(Array([9.61084 , 7.59375 , 8.542969], dtype=float32),
 Array([65.353714, 35.943752, 80.87344 ], dtype=float32))

In [11]:
# Same Jacobian, now computed with reverse-mode AD. This rebinds full_jac
# (previously the jacfwd result); the values agree with the forward-mode ones.
full_jac = jacrev(f)(evaluation_point)

In [12]:
# Display the reverse-mode Jacobian.
full_jac

Array([[ 57.66504 ,  76.88672 ,  57.66504 ,   9.61084 ],
       [ 15.1875  ,  45.5625  ,  25.3125  ,  11.390625],
       [ 42.714844, 119.60156 ,  39.867188,  25.628906]], dtype=float32)

In [13]:
# Reference vector-Jacobian product u @ J from the materialized Jacobian.
# u lives in output space (3 entries, one per component of f), so the first
# three entries of multiplication_point are reused as the cotangent.
# No `.T` is needed: transposing a 1-D array is a no-op, and `@` with a 1-D
# left operand already treats it as a row vector.
cotangent = multiplication_point[:-1]
cotangent @ full_jac

Array([33.175198, 76.88672 , 35.073635, 15.590919], dtype=float32)

In [14]:
# Reverse mode: vjp returns f(evaluation_point) and a function that maps a
# cotangent u (output space) to u @ J (input space). Rebinds primals_out.
primals_out, vjp_fun = vjp(f, evaluation_point)

In [15]:
# Apply the VJP with the same 3-entry cotangent used for the explicit u @ J above.
# The result is a tuple with one entry per primal argument of f (here just x).
vjp_fun(multiplication_point[:-1])

(Array([33.175194, 76.88672 , 35.073635, 15.590918], dtype=float32),)