Welcome to my JAX tutorial series. This is the second part. Wherever you look in machine learning these days, you run into something about JAX, and naturally you become curious about what JAX really is. I am here to answer that curiosity in as much detail as I can. You can find the list of my tutorials below:


**JAX Tutorials:**

* [1. Introduction to JAX](https://www.kaggle.com/code/goktugguvercin/introduction-to-jax)
* [2. Gradients and Jacobians in JAX](https://www.kaggle.com/code/goktugguvercin/gradients-and-jacobians-in-jax)
* [3. Automatic Differentiation in JAX](https://www.kaggle.com/code/goktugguvercin/automatic-differentiation-in-jax)
* [4. Just-In-Time Compilation in JAX](https://www.kaggle.com/code/goktugguvercin/just-in-time-compilation-in-jax)


<div style="width:100%;text-align: center;"> 
<img align=middle src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="250" height="250">
</div>

In this tutorial, I will go over the fundamental mathematical concepts behind the optimization of machine learning and deep learning models. First, I will explain the theory of *partial derivatives* and *gradient vectors*, together with code that uses the differentiation functions provided by JAX. Then, we will go through a comparative assessment to see how the JAX API differs from other autodiff libraries like TensorFlow and PyTorch. After that, we will study differentiation with *finite differences*, *total derivatives* and *Jacobians*, each supported by small example code blocks in JAX. Finally, we will investigate what the *Hessian matrix* is and its role in the *convexity* of multivariate functions. I hope you like it.

In [1]:
import jax
import jax.numpy as jnp

from jax import random
from jax import grad,value_and_grad
from jax.test_util import check_grads

# Partial (Directional) Derivatives and Gradient Vector:

Multivariate functions, commonly denoted as $f: R^n \rightarrow R$, take multiple values as input and produce one scalar output from them. The typical way to differentiate a multivariate function is to choose one input variable and compute the derivative with respect to it while treating all other variables as constants; the result is a partial derivative. The gradient of a multivariate function is a vector-valued function that collects all of its partial derivatives into a vector for a given $n$-dimensional input. The formulas below give the declaration and the definition of the gradient, respectively:

$$\nabla f(x): R^n \rightarrow R^n $$

$$\nabla f(x) = \begin{bmatrix}
    \frac{\partial f(x)}{\partial x_1}, \frac{\partial f(x)}{\partial x_2}, ...,  \frac{\partial f(x)}{\partial x_n} \\ 
\end{bmatrix}^T $$

JAX exposes the gradient operator from math as a transformer: it takes a Python function as input and returns another function that computes the gradient of the given one. To build a deeper understanding, let's examine it step by step:

* [$jax.grad()$](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is a transformer that acts like $\nabla$ in calculus
* $f(x)$ is the Python function that we define and pass to $jax.grad()$ as input
* $\nabla f(x) = jax.grad(f)$ is the gradient of the function $f(x)$

One detail deserves attention here. This transformer only works with Python functions that return a single scalar output; otherwise, it raises an error. This constraint mirrors the math, since in calculus the gradient is only defined for scalar-valued multivariate functions $f: R^n \rightarrow R$.

In the example below, a paraboloid in 3D space is defined as a Python function. To check the gradient function returned by the $jax.grad()$ transformer for the paraboloid, let's compare it with our explicit gradient function.

In [2]:
# f(x) = x1^2 + x2^2
# paraboloid in 3D space
def paraboloid(x):
    return jnp.sum(x**2)

# gradf(x) = [2x1, 2x2]
# Our explicit gradient function
def g_paraboloid(x):
    return 2 * x 

# JAX's grad operator
grad_paraboloid = grad(paraboloid)

# three different inputs
input = jnp.array([[0.2, 0.3], [2.4, 3.6], [4.4, 2.1]])
for x in input:
    print("Explicit Gradient Function: ", g_paraboloid(x))
    print("JAX Gradient Function: ", grad_paraboloid(x))
    print("")

Explicit Gradient Function:  [0.4 0.6]
JAX Gradient Function:  [0.4 0.6]

Explicit Gradient Function:  [4.8 7.2]
JAX Gradient Function:  [4.8 7.2]

Explicit Gradient Function:  [8.8 4.2]
JAX Gradient Function:  [8.8 4.2]
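
The scalar-output restriction mentioned earlier can be checked directly. The following is only a minimal sketch; `vector_valued` and `scalar_valued` are hypothetical helper names that do not appear in the original notebook. Applying `jax.grad` to a function that returns a vector should raise a `TypeError`, while reducing the output to a scalar makes the function differentiable again.

```python
import jax
import jax.numpy as jnp

# hypothetical helper: maps R^2 -> R^2, so jax.grad is not defined for it
def vector_valued(x):
    return x ** 2

# hypothetical helper: maps R^2 -> R, so jax.grad accepts it
def scalar_valued(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])

try:
    jax.grad(vector_valued)(x)
    raised = False
except TypeError as err:
    raised = True
    print("jax.grad refused the vector-valued function:", type(err).__name__)

scalar_grad = jax.grad(scalar_valued)(x)
print("Gradient of the scalar-valued function:", scalar_grad)
```

For vector-valued functions, the Jacobian transformers discussed later in this notebook are the appropriate tool.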



Apart from the input variables, you may also need to pass their coefficients to the Python function. For example, suppose we want to define a paraboloid over $R^5$ (a surface in $6$ dimensional space) in which each variable has its own coefficient. The function and its gradient are given below:

$$ f(x) = 3x_1^2 + 2x_2^2 + 5x_3^2 + x_4^2 + 4x_5^2 $$

$$ \nabla f(x) = \begin{bmatrix}
    6x_1, 4x_2, 10x_3,  2x_4, 8x_5\\ 
\end{bmatrix}^T $$

In that case, our Python function takes two different kinds of input: the variables and their corresponding coefficients. **How does $jax.grad()$ know which one of them is the variable set? Is it possible to differentiate the function with respect to the coefficients accidentally?** This is where the parameter **"argnums"** of $jax.grad()$ enters the picture: it specifies which positional argument(s) the function is differentiated with respect to. When it is not given explicitly, it defaults to 0, i.e. the first argument:

In [3]:
# f(x) = 3x1^2 + 2x2^2 + 5x3^2 + x4^2 + 4x5^2
def paraboloid2(coeff, x):
    return jnp.sum(coeff * x**2)

# taking the gradient of paraboloid w.r.t. x (positional argument 1)
grad_paraboloid2 = grad(paraboloid2, argnums=1)

coefficients = jnp.array([3, 2, 5, 1, 4]) # coefficients
input = jnp.array([2., 1., 3., 2., 4.]) # input in R^5


print(grad_paraboloid2(coefficients, input))

[12.  4. 30.  4. 32.]
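
As a small extension of the idea above, `argnums` also accepts a tuple, in which case the returned function produces one gradient per selected argument. This sketch redefines `paraboloid2` so that it is self-contained, and it assumes floating-point coefficients, because `jax.grad` only differentiates with respect to floating-point (or complex) inputs; an integer coefficient array would have to be converted first.

```python
import jax
import jax.numpy as jnp

def paraboloid2(coeff, x):
    return jnp.sum(coeff * x**2)

# float coefficients, so that they can be differentiated as well
coefficients = jnp.array([3., 2., 5., 1., 4.])
x = jnp.array([2., 1., 3., 2., 4.])

# a tuple for argnums yields a tuple of gradients, one per selected argument
grad_both = jax.grad(paraboloid2, argnums=(0, 1))
g_coeff, g_x = grad_both(coefficients, x)

print("df/dcoeff:", g_coeff)  # analytically x**2
print("df/dx:    ", g_x)      # analytically 2 * coeff * x
```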


# General Assessment for Automatic Differentiation of JAX, TensorFlow and PyTorch:

As you know, TensorFlow and PyTorch are the most widely used deep learning libraries, and both rely on a computation graph to differentiate neural networks efficiently. The computation graph is a data structure that records all mathematical operations involving the model parameters, and it is used to propagate gradients from the loss tensor back to every network layer by backpropagation. Although this technique is very useful and effective, the way it is exposed in code looks a little different from the way derivatives are written in math. JAX differs from those two dominant libraries in that it applies differentiation directly to functions and follows the concepts of calculus closely. What we do in theory (math) and in practice (coding) run in parallel in JAX, as the examples covered so far in this notebook illustrate.

At this point, you may ask how difficult it is to differentiate neural networks with JAX. We define a loss function composed of two operations, the forward pass over a batch and the error computation, and then differentiate it with respect to the model parameters. The loss function only takes the *parameters* and the *input data* as arguments. Organizing the parameters of many layers does not lead to a huge argument list, because JAX works with a flexible data structure called **pytrees** to cope with such large and nested collections. Pytrees will be covered in another notebook, but we can go over the general idea without much detail using a linear regression model:

*Let's first introduce our model function and the mean squared error loss:*

$$ \hat{y} = f(x, w, b) = xw + b  \,\,\,\,\,\,\,\, \hat{y} \in R^N$$ 

$$ L = MSELoss(\hat{y}, y) = \frac{1}{N} \| \hat{y} - y \|_2^2 $$

*During backpropagation, the gradients of the loss with respect to the model parameters w and b are computed. Let's take a closer look at them with the chain rule:*

$$ \frac{\partial L}{\partial \hat{y}} = \frac{2}{N} \, (\hat{y} - y) $$

$$ \frac{\partial \hat{y}}{\partial w} = x, \,\,\,\,\, \frac{\partial \hat{y}}{\partial b} = 1 $$

$$ \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial w} = \frac{\partial L}{\partial w} = \frac{2x}{N} \, (\hat{y} - y) $$ 

$$ \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial b} = \frac{\partial L}{\partial b} = \frac{2}{N} \, (\hat{y} - y) $$
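
The chain-rule expressions above are written per sample. For a batch of $N$ samples they sum over the samples, which in matrix form becomes $\frac{2}{N} x^T (\hat{y} - y)$ for $w$ and $\frac{2}{N} \sum_i (\hat{y}_i - y_i)$ for $b$. The sketch below checks this batched form against `jax.grad`; the function name `mse`, the random data and the shapes are assumptions made only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

# hypothetical batched linear model with a scalar bias
def mse(w, b, x, y):
    y_hat = x @ w + b
    return jnp.mean((y_hat - y) ** 2)   # (1/N) * ||y_hat - y||^2

key_x, key_y, key_w = random.split(random.PRNGKey(0), 3)
x = random.normal(key_x, (8, 3))   # N = 8 samples, 3 features each
y = random.normal(key_y, (8,))
w = random.normal(key_w, (3,))
b = 0.5

# chain rule by hand, summed over the batch
n = x.shape[0]
residual = (x @ w + b) - y
manual_dw = (2.0 / n) * (x.T @ residual)
manual_db = (2.0 / n) * jnp.sum(residual)

# the same gradients from automatic differentiation
auto_dw, auto_db = jax.grad(mse, argnums=(0, 1))(w, b, x, y)

print("dL/dw manual:", manual_dw, " jax:", auto_dw)
print("dL/db manual:", manual_db, " jax:", auto_db)
```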

In the forward pass, the model prediction is computed first, and then the total error of those predictions is measured by the loss function. In the backward pass, the gradients of the loss with respect to the model parameters are computed. JAX provides another transformer that generates a function capable of doing both the forward and the backward pass at the same time: $ jax.value\_and\_grad() $. When the loss function is passed to this transformer, it returns a new function that produces both the loss value and the gradients as output. Let's implement the forward and backward passes of the linear model given above with this transformer:

In [4]:
def random_params(in_features, out_features, key, scale):
    w_key, b_key = random.split(key)
    w = scale *  random.normal(w_key, (in_features, out_features))
    b = scale * random.normal(b_key, (out_features, ))
    return {"W": w, "b": b}  # small pytree

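def mse_loss(data, params):
    x, labels = data
    preds = jnp.dot(x, params["W"]) + params["b"] # forward pass
    # NOTE: x.size counts every element of x (samples * features = 3 here), so the loss is
    # scaled by 1/3 instead of 1/N with N = 1 sample; x.shape[0] would match the formula exactly
    return (1. / x.size) * jnp.sum((preds - labels)**2) # error computation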

# Gradient of mse loss with respect to w
# Explicit gradient computation
# It implements the derivative rule dL/dw given above
def gw_mse_loss(data, params):
    x, labels = data
    preds = jnp.dot(x, params["W"]) + params["b"]
    return (2. / x.size) * (preds - labels) * x

# Gradient of mse loss with respect to b
# Explicit gradient computation
# It implements the derivative rule dL/db given above
def gb_mse_loss(data, params):
    x, labels = data
    preds = jnp.dot(x, params["W"]) + params["b"]
    return (2. / x.size) * (preds - labels)

# Each sample has 3 features; dataset has only 1 sample. So, data format is 1x3
# Each sample has corresponding 1 label value, dataset has only 1 sample. So, label format is 1x1
x = jnp.array([[1.2, 4.5, 2.3]])
labels = jnp.array([[6.5]])

# randomly generating regression parameters
params = random_params(3, 1, random.PRNGKey(0), 0.1)

# Transformer returns a function that can produce both loss and its gradients w.r.t. w and b 
loss_grad_fn = value_and_grad(mse_loss, argnums=(1))
loss, gradients = loss_grad_fn((x, labels), params)

print("Loss value: ", loss)
print("Explicit Gradient Function: ", gw_mse_loss((x, labels), params), gb_mse_loss((x, labels), params))
print("Jax Gradient Function: ", gradients)

Loss value:  14.089611
Explicit Gradient Function:  [[ -5.201159 -19.504345  -9.968887]] [[-4.334299]]
Jax Gradient Function:  {'W': DeviceArray([[ -5.201159],
             [-19.504345],
             [ -9.968887]], dtype=float32), 'b': DeviceArray([-4.334299], dtype=float32)}


In the code above, we not only computed the total loss and the gradients of our linear regression model at the same time, but also saw how JAX differentiates with respect to containers and nested structures, such as the small dictionary in our example. What it does is differentiate the function with respect to each leaf in the container and return the results in the same structure. To organize the parameters of a neural network, we can define an enclosing container of sub-dictionaries, each of which holds the parameters of a particular part, segment or layer of the network, as a pytree.
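
To make the nested-container idea more concrete, here is a hedged sketch of a tiny two-layer network whose parameters live in a dictionary of sub-dictionaries. The layer names (`"hidden"`, `"output"`), the shapes and the helper names `init_params` and `loss_fn` are all assumptions chosen for illustration. The gradients come back with the same nested structure as the parameters, so a plain gradient-descent step can be applied leaf by leaf with `jax.tree_util.tree_map`.

```python
import jax
import jax.numpy as jnp
from jax import random

# hypothetical parameter container: one sub-dictionary per layer
def init_params(key):
    k1, k2 = random.split(key)
    return {
        "hidden": {"W": 0.1 * random.normal(k1, (3, 4)), "b": jnp.zeros(4)},
        "output": {"W": 0.1 * random.normal(k2, (4, 1)), "b": jnp.zeros(1)},
    }

def loss_fn(params, x, y):
    h = jnp.tanh(x @ params["hidden"]["W"] + params["hidden"]["b"])
    preds = h @ params["output"]["W"] + params["output"]["b"]
    return jnp.mean((preds - y) ** 2)

params = init_params(random.PRNGKey(0))
x = jnp.ones((5, 3))
y = jnp.ones((5, 1))

# gradients are returned as a pytree shaped like params (argnums defaults to 0)
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

# one plain gradient-descent step, applied to every leaf
learning_rate = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

print("Loss value:", loss)
print("Gradient shapes:", jax.tree_util.tree_map(jnp.shape, grads))
```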

# Finite Differences:

Finite differences measure how much the output of a function changes when its input is changed by a small amount. They underlie the limit definition of the derivative, and they give an intuitive picture of what the derivative means: how strongly the output of a function tends to change when its input is perturbed slightly. It is like the function's sensitivity to its input.

With the help of finite differences, we can approximate the derivative of a function numerically, which is useful both for solving differential equations numerically and for checking the correctness of computed gradients. There are three common types of finite differences:

1. Forward Differences: $\Delta_h f(x) = f(x + h) - f(x)$
2. Backward Differences: $\nabla_h f(x) = f(x) - f(x - h)$
3. Central Differences: $\delta_h f(x) = f(x + \frac{h}{2}) - f(x - \frac{h}{2}) = \Delta_{\frac{h}{2}} f(x) + \nabla_{\frac{h}{2}} f(x)$


As you may have noticed, all three methods answer the same question: how much does changing the input of the function $f(x)$ by $h$ influence its output? They only differ in where the input is perturbed relative to $x$. Dividing a finite difference by the amount of change $h$ and letting $h$ go to zero gives the definition of the derivative; for a small but nonzero $h$, each quotient is an approximation of it:

$$ f'(x) = \displaystyle\lim_{h \rightarrow 0}\frac{f(x+h) - f(x)}{h} = \lim_{h \rightarrow 0}\frac{\Delta_h f(x)}{h}, \,\,\,\,\,\,\,\, f'(x) \approx \frac{\Delta_h f(x)}{h} \approx \frac{\nabla_h f(x)}{h} \approx \frac{\delta_h f(x)}{h}  $$
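
The three difference quotients can be compared numerically. The sketch below is an illustration under assumed values: the test function $\sin$, the point $x = 1$ and the step $h = 10^{-2}$ are my choices, not part of the original notebook. Each quotient is compared with the derivative returned by `jax.grad`.

```python
import jax
import jax.numpy as jnp

f = jnp.sin          # assumed test function, f'(x) = cos(x)
x, h = 1.0, 1e-2     # assumed evaluation point and step size

forward = (f(x + h) - f(x)) / h               # forward difference quotient
backward = (f(x) - f(x - h)) / h              # backward difference quotient
central = (f(x + h / 2) - f(x - h / 2)) / h   # central difference quotient
exact = jax.grad(f)(x)                        # derivative from automatic differentiation

for name, approx in [("forward", forward), ("backward", backward), ("central", central)]:
    print(f"{name:8s} estimate: {approx:.6f}  abs error: {abs(approx - exact):.2e}")
```

For this smooth function the central quotient is expected to be noticeably closer to the true derivative than the one-sided quotients, because its leading error term is proportional to $h^2$ rather than $h$.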

JAX provides a very handy function called $ jax.test\_util.check\_grads()$ that computes finite differences of the given function with a very small $h$ value and compares the result with the gradients obtained by automatic differentiation. If the two are not close enough to each other, as determined by an absolute and a relative tolerance value, the function raises an error. To test $check\_grads()$ and observe how it behaves when the tolerance is exceeded, you can try a larger $h$ value, since the finite-difference approximation generally becomes less accurate as $h$ increases. To get a deeper understanding of how the function works, you can look at the [documentation](https://github.com/google/jax/pull/2656/commits/7b80574101b1d789bf22b074c808f806e393b37c). Let's try it on the MSE loss from the previous code block:

In [5]:
data = (x, labels)
func_args = (data, params)  # input arguments to mse loss
diff_order = 1  # first order differentiation
epsilon = 10**(-3) # h (try with h=0.1)
abs_tol = 10**-2  # abs error between forward_diff/h and gradient should be lower than 0.01

try:
    # check_grads returns silently when finite differences and autodiff gradients agree
    check_grads(f=mse_loss, args=func_args, order=diff_order, eps=epsilon, atol=abs_tol)
except Exception as ex:
    print(type(ex).__name__, ex)
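
To see what a failing check looks like, one option is to plant a wrong derivative on purpose. The sketch below is an assumption-laden illustration and not part of the original notebook: it uses `jax.custom_vjp` only as a device to attach an incorrect backward rule to a squaring function, and the helper name `gradients_agree` is hypothetical. Only reverse mode is checked, since a `custom_vjp` function does not define a forward-mode rule.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def square(x):
    return x ** 2

@jax.custom_vjp
def bad_square(x):
    return x ** 2

def bad_square_fwd(x):
    return x ** 2, x            # keep x as the residual for the backward pass

def bad_square_bwd(x, g):
    return (3.0 * x * g,)       # deliberately wrong: the correct rule is 2 * x * g

bad_square.defvjp(bad_square_fwd, bad_square_bwd)

# hypothetical helper: True if autodiff matches finite differences, False otherwise
def gradients_agree(fn, x):
    try:
        check_grads(fn, (x,), order=1, modes=("rev",), eps=1e-3, atol=1e-2, rtol=1e-2)
        return True
    except AssertionError:
        return False

x = jnp.array([1.5, -0.5, 2.0])
ok_correct = gradients_agree(square, x)
ok_wrong = gradients_agree(bad_square, x)
print("correct rule passes:", ok_correct)
print("wrong rule passes:  ", ok_wrong)
```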

# Total (Fréchet) Derivative and Jacobian Matrix:

Total differentiation is a more general technique than partial differentiation, because it approximates the target function linearly with respect to all input variables at the same time. Hence, it can be considered an extension of directional derivatives. One main reason such a general notion is needed is to differentiate vector-valued functions, denoted as $f: R^n \rightarrow R^m$, in a unified way.

A vector-valued function can be represented as a stack of $m$ multivariate functions, each of which is fed the same input vector in $R^n$ and generates one scalar component of the output vector in $R^m$. Let $g: R^n \rightarrow R^m$ be the vector-valued function to be differentiated. Then $g$ can be totally differentiable at a point only if all of its components ($g_1, g_2, g_3, ..., g_m$) are partially differentiable there; if, in addition, those partial derivatives are continuous, $g$ is guaranteed to be totally differentiable. In that case, the total derivative of $g$ is represented by the **Jacobian matrix** of dimension $m \times n$, whose $i$-th row is the gradient of the component function $g_i$.


$$ g(x) = \begin{bmatrix}
    g_1(x)\\
    g_2(x)\\
    \vdots \\
    g_m(x)\\
\end{bmatrix} $$

$$ 
Dg(x) = J_g(x) =\begin{bmatrix}
    \partial_{e_1}g_1(x) & \partial_{e_2}g_1(x) & \cdots & \partial_{e_n}g_1(x)\\
    \partial_{e_1}g_2(x) & \partial_{e_2}g_2(x) & \cdots & \partial_{e_n}g_2(x)\\
\vdots & \vdots & \ddots & \vdots\\ 
\partial_{e_1}g_m(x) & \partial_{e_2}g_m(x) & \cdots & \partial_{e_n}g_m(x)
\end{bmatrix} = \begin{bmatrix}
    \text{---} & \nabla g_1(x) & \text{---}\\
    \text{---} & \nabla g_2(x) & \text{---}\\
\vdots & \vdots & \vdots\\ 
\text{---} & \nabla g_m(x) & \text{---}\\
\end{bmatrix} $$

JAX provides two different transformers, [$jax.jacfwd()$](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html#jax.jacfwd) and [$jax.jacrev()$](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html#jax.jacrev), to compute Jacobian matrices of vector-valued functions. As their names suggest, the former uses forward-mode automatic differentiation to construct the entire Jacobian, while the latter does the same thing with reverse mode. Naturally, we ask ourselves this question:

**If we have two different functions to compute the same Jacobian, which one do we opt for?**

Their computational efficiency depends on the dimensions of the domain and codomain. In the case of $m \gg n$, we end up with a tall Jacobian, which $jacfwd()$ computes much more efficiently than $jacrev()$, because forward mode builds the Jacobian one column at a time. In the opposite case ($n \gg m$, a wide Jacobian), reverse mode builds it one row at a time and does a better job. For near-square Jacobians, $jacfwd()$ tends to have an edge over $jacrev()$.

Their signatures are almost the same as that of $grad()$: they take the function to be differentiated as an argument and return a new function responsible for computing the Jacobian at any input. In other words, they do not compute Jacobian matrices directly; instead, they return a function that can do it at any point where $g$ is differentiable. Also note that, as with $grad()$, both transformers let us specify which input argument the Jacobian is computed with respect to through the parameter $argnums$.
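
As a quick sanity check of the description above, the following sketch computes the Jacobian of a small function with both transformers and compares the rows with the gradients of the individual components. The function `g` used here ($R^3 \rightarrow R^2$) is a hypothetical example chosen for illustration.

```python
import jax
import jax.numpy as jnp

# hypothetical vector-valued function g: R^3 -> R^2
def g(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

x = jnp.array([1.0, 2.0, 0.5])

jac_fwd = jax.jacfwd(g)(x)   # forward mode
jac_rev = jax.jacrev(g)(x)   # reverse mode

# each row of the Jacobian is the gradient of one scalar component g_i
row0 = jax.grad(lambda v: g(v)[0])(x)
row1 = jax.grad(lambda v: g(v)[1])(x)

print("Jacobian shape (m x n):", jac_fwd.shape)
print("jacfwd and jacrev agree:", jnp.allclose(jac_fwd, jac_rev))
print("rows match component gradients:", jnp.allclose(jac_fwd, jnp.stack([row0, row1])))
```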

Linear transformations are a very common type of vector-valued function. They are often used to project data in $R^n$ into a lower dimensional space in order to compress it and obtain a higher-level representation. The opposite direction is also used, for instance in the decoder part of autoencoders during data reconstruction. A linear transformation is expressed by the equation $y = f(x) = A \cdot x$, and differentiating it with respect to $x$ directly gives the transformation matrix $A$. Let's examine it with the help of the code below.

In [6]:
key = random.PRNGKey(137)
key1, key2 = random.split(key)

input = random.uniform(key1, (10, ))  # input vector in R^10
trans_matrix = random.uniform(key2, (20, 10))  # transformation matrix of shape 20x10 to project it into R^20

def affine_transform(input, matrix):  # linear map y = A x (there is no bias term, despite the name)
    return matrix @ input
    
jacobian_fn = jax.jacfwd(affine_transform, argnums=0)  # it returns the function in charge of computing the jacobian w.r.t. input
jacobian = jacobian_fn(input, trans_matrix)  # y = f(x) = Ax, dy/dx = A
print(jnp.allclose(trans_matrix, jacobian))  # tolerance-based comparison is safer than == for floats

True
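
The `argnums` parameter works the same way for the Jacobian transformers. As a hedged sketch with small, assumed shapes (a $2 \times 3$ matrix and a hypothetical function name `linear_map`), differentiating $y = A x$ with respect to the matrix instead of the input yields an array whose shape is the output shape followed by the shape of $A$, with entries $\partial y_i / \partial A_{jk} = \delta_{ij} x_k$.

```python
import jax
import jax.numpy as jnp

# hypothetical helper with the same argument order as the cell above
def linear_map(x, matrix):
    return matrix @ x

x = jnp.array([1.0, 2.0, 3.0])        # input in R^3
A = jnp.arange(6.0).reshape(2, 3)     # 2x3 transformation matrix

# Jacobian with respect to the matrix (positional argument 1)
jac_wrt_A = jax.jacfwd(linear_map, argnums=1)(x, A)
print("shape:", jac_wrt_A.shape)      # output shape (2,) followed by A's shape (2, 3)

# analytic form: d y_i / d A_jk = delta_ij * x_k
expected = jnp.einsum("ij,k->ijk", jnp.eye(2), x)
print("matches analytic form:", jnp.allclose(jac_wrt_A, expected))
```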


# Hessian Matrix:

As mentioned earlier in the notebook, multivariate functions $f: R^n \rightarrow R$ are partially differentiated with respect to each input variable, and stacking those partial derivatives into a vector gives the gradient, denoted as $\nabla f: R^n \rightarrow R^n$. The gradient is a vector-valued function, and if we differentiate it by computing its Jacobian, we obtain the Hessian matrix of our multivariate function $f$. The domain and codomain of the gradient are the same Euclidean space, which makes the Hessian always an $n \times n$ square matrix.

$$\nabla f(x): R^n \rightarrow R^n $$

$$ H_f(x) = J_{\nabla f}(x) =\begin{bmatrix}
    \frac{\partial^2 f(x)}{\partial x_1^2} & \frac{\partial^2 f(x)}{\partial x_1 \partial x_2} & \cdots & \frac{\partial^2 f(x)}{\partial x_1 \partial x_n}\\
    \frac{\partial^2 f(x)}{\partial x_2 \partial x_1} & \frac{\partial^2 f(x)}{\partial x_2^2} & \cdots & \frac{\partial^2 f(x)}{\partial x_2 \partial x_n}\\
\vdots & \vdots & \ddots & \vdots\\ 
\frac{\partial^2 f(x)}{\partial x_n \partial x_1} & \frac{\partial^2 f(x)}{\partial x_n \partial x_2} & \cdots & \frac{\partial^2 f(x)}{\partial x_n^2}
\end{bmatrix} $$

First of all, the Hessian matrix lets us approximate the target function $f$ quadratically through its second-order Taylor expansion, which is the idea behind Newton's method. In that way, convergence to a minimum of $f$ is typically faster than with gradient descent. For example, linear regression with a squared-error loss needs just one iteration of Newton's method, because its loss is exactly quadratic. In addition, we do not need to search for an optimal learning rate, since the inverse of the Hessian automatically scales the update direction (it determines the step size). One iteration of Newton's method looks as follows:

$$ W_{n+1} := W_n - H_f(W_n)^{-1} \cdot \nabla f(W_n) $$

On the other hand, unlike gradient descent, Newton's method does not always work well. Since it approximates the function quadratically, the direction of the update depends on the curvature. In particular, Newton's method moves toward a local maximum if the current neighborhood is a concave region of $f$. Hence, random initialization of the model parameters is a risky strategy. Furthermore, computing the inverse Hessian and keeping it in memory is quite costly for large machine learning models like neural networks. For example, a neural network with $k$ parameters has a Hessian with $k^2$ entries, and inverting it costs about $O(k^3)$ operations. That is why Newton's method is generally avoided in deep learning. To alleviate these drawbacks, approaches such as BFGS and Levenberg-Marquardt have been proposed. How Newton's method converges to extremum points is illustrated with simulations in [Jeremy Watt's Github Tutorials](https://jermwatt.github.io/machine_learning_refined/notes/4_Second_order_methods/4_4_Newtons.html).
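
The one-iteration claim for linear regression can be checked with a short experiment. This is only a sketch under assumptions: the synthetic data, the true weights `w_true` and the noise level are invented for the illustration. The Hessian is built as the Jacobian of the gradient, and the Newton step is obtained by solving a linear system rather than forming the inverse explicitly.

```python
import jax
import jax.numpy as jnp
from jax import random

# assumed synthetic regression problem: 50 samples, 3 features, small noise
key_x, key_n = random.split(random.PRNGKey(0))
X = random.normal(key_x, (50, 3))
w_true = jnp.array([1.5, -2.0, 0.5])
y = X @ w_true + 0.01 * random.normal(key_n, (50,))

def loss(w):
    return jnp.mean((X @ w - y) ** 2)   # quadratic in w

grad_fn = jax.grad(loss)
hess_fn = jax.jacfwd(grad_fn)           # Hessian as the Jacobian of the gradient

w0 = jnp.zeros(3)
# Newton update: solve H d = grad instead of computing the inverse of H
w1 = w0 - jnp.linalg.solve(hess_fn(w0), grad_fn(w0))

print("weights after one Newton step:", w1)
print("gradient norm after the step: ", jnp.linalg.norm(grad_fn(w1)))
```

Because the loss is exactly quadratic, the gradient after this single step should be zero up to floating-point error.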

Apart from optimization, the Hessian matrix also lets us analyze the convexity of functions, so that convex and concave regions on the surface of a multi-dimensional $f$ can be spotted. To do this, the Hessian of the target function $f$ at a point $x$ is decomposed into its eigenvalues, and the sign of each eigenvalue is checked. The sign tells whether the function is locally convex, concave or flat along the corresponding eigenvector direction (which coincides with a coordinate axis when the Hessian is diagonal). How the eigenvalues of the Hessian are interpreted is summarized in the following table:

<div style="width:100%;text-align: center;"> 
<img align=middle src="https://github.com/GoktugGuvercin/JAX-Tutorials/blob/main/Gradients%20and%20Jacobians%20in%20JAX/Images/Eigenvalues%20of%20Hessian%20Matrix.png?raw=true" width="800" height="400">
</div>

To compute the Hessian matrix in JAX, we do not need an additional transformer: the Hessian is the Jacobian of the gradient, and JAX already provides transformers for both. Let's examine the functions $f(x,y) = x^2 + y^3$ and $g(x,y) = x^2 - y^2$. Both take an input in $R^2$ and produce one scalar output; hence, both are multivariate functions. Since the input space of both functions is $R^2$, their Hessian matrices are of shape $2 \times 2$ and therefore have 2 eigenvalues. We can easily compute the Hessians by hand by differentiating the gradients:

$$ H_f(x,y) = J_{\nabla f}(x,y) =\begin{bmatrix}
    2 & 0\\
    0 & 6y\\
\end{bmatrix} \,\,\, \rightarrow \lambda_1 = 2, \,\, \lambda_2 = 6y $$

$$ H_g(x,y) = J_{\nabla g}(x,y) =\begin{bmatrix}
    2 & 0\\
    0 & -2\\
\end{bmatrix} \,\,\, \rightarrow \lambda_1 = 2, \,\, \lambda_2 = -2 $$

Since all off-diagonal entries are zero, we can read the eigenvalues of the Hessian matrices above directly from their diagonal entries without performing any decomposition. The convexity of function $g$ does not depend on the location on the surface; it is fixed: $g$ is convex along the x axis and concave along the y axis. On the other hand, the curvature of function $f$ along the second basis vector $e_2$ (the y axis) depends entirely on the value of the input $y$: it is convex for $y > 0$ and concave for $y < 0$. These convexity characteristics can also be observed in the plots below:

<div style="width:100%;text-align: center;"> 
<img align=middle src="https://github.com/GoktugGuvercin/JAX-Tutorials/blob/main/Gradients%20and%20Jacobians%20in%20JAX/Images/Function%20Curves.png?raw=true" width="800" height="400">
</div>


Plotting the functions, as depicted above, directly reveals their convexity. However, with more than 2 input variables, the graph of the function lives in a space of 4 or more dimensions and can no longer be plotted. In this case, investigating the Hessian matrix becomes essential. Let's look at how to compute the Hessians of those two functions with JAX and compare them with the hand-derived Hessian matrices in the following code:

In [7]:
# function f(x, y) is implemented as f(w) in which w is a vector in R^2
# Packing the inputs into one vector lets a single grad call differentiate w.r.t. all of them
def f(w):
    return w[0]**2 + w[1]**3

# function g(x, y) is implemented as g(w) in which w is a vector in R^2
# Packing the inputs into one vector lets a single grad call differentiate w.r.t. all of them
def g(w):
    return w[0]**2 - w[1]**2

# manually coded hessian matrix of function f()
def h_f(w):
    return jnp.array([[2.0, 0.0], [0.0, 6*w[1]]])

# manually coded hessian matrix of function g()
def h_g(w):
    return jnp.array([[2.0, 0.0], [0.0, -2.0]])

# hessian matrices of the functions f and g are computed by jax transformers
hessian_f = jax.jacfwd(jax.grad(f))
hessian_g = jax.jacfwd(jax.grad(g))

# two different (x,y) input pairs are prepared to be passed to hessians of functions f and g
input = jnp.array([[2.0, 1.0], [2.4, -5.0]])

print("Hessian Matrix of Function f()\n")
for w in input:
    print("Input: ", w)
    print("Manual hessian:", h_f(w))
    print("Jax hessian: ", hessian_f(w))
    print()

print("\nHessian Matrix of Function g()\n")
for w in input:
    print("Input: ", w)
    print("Manual hessian:", h_g(w))
    print("Jax hessian:", hessian_g(w))
    print()

Hessian Matrix of Function f()

Input:  [2. 1.]
Manual hessian: [[2. 0.]
 [0. 6.]]
Jax hessian:  [[2. 0.]
 [0. 6.]]

Input:  [ 2.4 -5. ]
Manual hessian: [[  2.   0.]
 [  0. -30.]]
Jax hessian:  [[  2.   0.]
 [  0. -30.]]


Hessian Matrix of Function g()

Input:  [2. 1.]
Manual hessian: [[ 2.  0.]
 [ 0. -2.]]
Jax hessian: [[ 2.  0.]
 [ 0. -2.]]

Input:  [ 2.4 -5. ]
Manual hessian: [[ 2.  0.]
 [ 0. -2.]]
Jax hessian: [[ 2.  0.]
 [ 0. -2.]]
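
The eigenvalue rule described earlier can also be applied in code. In the sketch below, the helper `classify` and its tolerance are hypothetical and encode my own labeling convention, which may differ in wording from the table above. It uses `jnp.linalg.eigvalsh`, which is suitable because Hessians of smooth functions are symmetric. JAX also ships `jax.hessian` as a convenience wrapper, which is compared here with the `jacfwd(grad(...))` composition.

```python
import jax
import jax.numpy as jnp

def f(w):
    return w[0]**2 + w[1]**3

def g(w):
    return w[0]**2 - w[1]**2

# hypothetical helper: labels the local curvature from the signs of the Hessian eigenvalues
def classify(hessian, tol=1e-6):
    eigvals = jnp.linalg.eigvalsh(hessian)   # real eigenvalues of a symmetric matrix
    if bool(jnp.all(eigvals > tol)):
        return "convex"
    if bool(jnp.all(eigvals < -tol)):
        return "concave"
    if bool(jnp.any(eigvals > tol)) and bool(jnp.any(eigvals < -tol)):
        return "saddle"
    return "flat in at least one direction"

hessian_f = jax.jacfwd(jax.grad(f))
hessian_g = jax.jacfwd(jax.grad(g))

points = jnp.array([[2.0, 1.0], [2.4, -5.0]])
labels_f = [classify(hessian_f(w)) for w in points]
labels_g = [classify(hessian_g(w)) for w in points]
print("f:", labels_f)
print("g:", labels_g)

# jax.hessian gives the same matrix as the manual composition
same_as_builtin = jnp.allclose(jax.hessian(f)(points[0]), hessian_f(points[0]))
print("jax.hessian agrees:", same_as_builtin)
```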



In this tutorial, I tried to highlight the differentiation of multivariate and vector-valued functions. To this end, partial derivatives, total derivatives and finite differences were investigated. However, how these concepts are used in automatic differentiation has not been covered here; it will be the main topic of the next notebook.
