Welcome to my JAX tutorial series. This is the third part of the series. Wherever you look in the machine learning domain, you probably encounter something about JAX, and naturally you become curious about what JAX really is. I am here to satisfy that curiosity in as much detail as possible. You can find the list of my tutorials below:


**JAX Tutorials:**

* [1. Introduction to JAX](https://www.kaggle.com/code/goktugguvercin/introduction-to-jax)
* [2. Gradients and Jacobians in JAX](https://www.kaggle.com/code/goktugguvercin/gradients-and-jacobians-in-jax)
* [3. Automatic Differentiation in JAX](https://www.kaggle.com/code/goktugguvercin/automatic-differentiation-in-jax)
* [4. Just-In-Time Compilation in JAX](https://www.kaggle.com/code/goktugguvercin/just-in-time-compilation-in-jax)

<div style="width:100%;text-align: center;"> 
<img align=middle src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="250" height="250">
</div>

In this tutorial, I will try to explain what automatic differentiation really is, how it works, and how it is used to differentiate neural networks. After that, we will look at how backpropagation in neural networks can be interpreted in terms of differential geometry, and I will explain the commonly used terms *pushforward* and *pullback*. Finally, we will bring all of these concepts together with the autodiff functions in JAX. The tutorial focuses on the theoretical side of automatic differentiation, but I will also try to provide simple examples and illustrations. To understand this tutorial, you should first study the previous notebook, "Gradients and Jacobians in JAX".

```python
import jax
import jax.numpy as jnp
from jax import random
```

# Automatic Differentiation:

Automatic differentiation, also called algorithmic differentiation, computational differentiation, or simply autodiff, is a set of techniques for differentiating mathematical functions expressed as computer programs. It applies the rules of symbolic differentiation to each elementary operation, but instead of producing a mathematical expression for the derivative, it directly computes the derivative's numerical value. For this reason, it is commonly confused with **numerical** and **symbolic differentiation**, yet it differs from both: unlike numerical differentiation, it does not approximate derivatives with finite differences, and unlike symbolic differentiation, it does not build derivative expressions.

The general idea behind how automatic differentiation works is based on primitive mathematical operations such as binary arithmetic, trigonometric functions, logarithms, and exponentials. Every numerical computer program is a composition of these elementary operations, and what an autodiff system does is break that composition down to the atomic level to obtain its primitives and then build a computation graph from them. Each node in this graph stores three things: the mathematical operator it applies, the result value computed by that operator from its inputs, and pointers to its parents, that is, the nodes whose results were passed to it as inputs. In addition, each node records which arguments were passed to its operator. This graph and its nodes enable the autodiff procedure to differentiate the target program step by step and to combine the local derivatives with the chain rule, since they reveal all dependencies between the inputs, the intermediate results, and the output. The order in which the chain rule is traversed and the derivatives are accumulated gives rise to two different modes of automatic differentiation. Let's take a closer look at both modes using an example function called $L$:

*Target Function:*

$$w = L(t) = j \circ h \circ g \circ f(t) = e^{2t^2} + 3$$

*Primitive Operations:*

$$x = f(t) = t^2$$
$$y = g(x) = 2x$$
$$z = h(y) = e^y$$
$$w = j(z) = z + 3$$
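To connect this decomposition with JAX, here is a small sketch. The Python names `f`, `g`, `h`, `j`, and `L` simply mirror the symbols above and are only illustrative. `jax.make_jaxpr` prints the primitive operations JAX records while tracing $L$, and `jax.grad` is compared with the hand-derived derivative $\frac{dw}{dt} = 4t\,e^{2t^2}$. The exact jaxpr printout may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# The four primitives of the target function (illustrative names mirroring the math)
f = lambda t: t ** 2        # x = f(t) = t^2
g = lambda x: 2.0 * x       # y = g(x) = 2x
h = lambda y: jnp.exp(y)    # z = h(y) = e^y
j = lambda z: z + 3.0       # w = j(z) = z + 3

def L(t):
    # w = j(h(g(f(t)))) = e^{2t^2} + 3
    return j(h(g(f(t))))

t0 = 0.5

# The jaxpr lists the primitive operations recorded while tracing L
jaxpr = jax.make_jaxpr(L)(t0)
print(jaxpr)

# Autodiff result vs. the hand-derived derivative dw/dt = 4t * e^{2t^2}
auto = jax.grad(L)(t0)
manual = 4.0 * t0 * jnp.exp(2.0 * t0 ** 2)
print(auto, manual)
```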

# Forward Accumulation:


In forward accumulation, one independent variable of the target function is fixed, and the derivative of the function output with respect to that variable is recursively expanded by the chain rule. Here we fix our only independent variable, $t$, so we need to expand $\partial w / \partial t$. The expansion substitutes, step by step, the derivative of each intermediate variable with respect to its input until the fixed variable $t$ is reached; in other words, the expansion walks the path from $w$ to $t$. Once the expansion is complete, it is unwound by multiplying the values of the derivatives through the chain rule. This means the products are evaluated from $t$ to $w$, and the accumulated derivatives of the intermediate functions are repeatedly multiplied along the way. Since derivatives are propagated from the input $t$ toward the output $w$, this mode is called forward propagation or forward accumulation. Let's recursively expand $\partial w / \partial t$ with forward accumulation:


$$\frac{\partial w}{\partial t} = \frac{\partial w}{\partial z} \frac{\partial z}{\partial t}$$

$$\frac{\partial w}{\partial t} = \frac{\partial w}{\partial z} (\frac{\partial z}{\partial y} \frac{\partial y}{\partial t})$$

$$\frac{\partial w}{\partial t} = \frac{\partial w}{\partial z} (\frac{\partial z}{\partial y} (\frac{\partial y}{\partial x} \frac{\partial x}{\partial t}))$$

In this expansion, we simply write out each intermediate derivative in the chain rule step by step, following the order $z \rightarrow y \rightarrow x \rightarrow t$. This order of recursive expansion places the parentheses that determine the evaluation priority, which forces the derivatives to be multiplied in the reverse order, that is $t \rightarrow x \rightarrow y \rightarrow z$.
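To make this evaluation order concrete, the following sketch carries a pair (value, derivative with respect to $t$) through the primitives in the order $t \rightarrow x \rightarrow y \rightarrow z \rightarrow w$, which is the innermost-first evaluation of the parenthesized product above. It is an illustration of the idea, not JAX's internal implementation; the helper name `forward_accumulation` is hypothetical. The result is compared with `jax.jvp`, seeded with the tangent $1$ for $t$.

```python
import jax
import jax.numpy as jnp

def forward_accumulation(t):
    # Hypothetical helper: seed the derivative of the fixed input, dt/dt = 1
    dt = 1.0
    # Each step computes a primal value and multiplies the incoming derivative by the
    # local derivative (innermost parenthesis first: t -> x -> y -> z -> w)
    x = t ** 2
    dx = 2.0 * t * dt          # dx/dt
    y = 2.0 * x
    dy = 2.0 * dx              # dy/dx * dx/dt
    z = jnp.exp(y)
    dz = jnp.exp(y) * dy       # dz/dy * (dy/dx * dx/dt)
    w = z + 3.0
    dw = 1.0 * dz              # dw/dz * (dz/dy * (dy/dx * dx/dt))
    return w, dw

L = lambda t: jnp.exp(2.0 * t ** 2) + 3.0   # w = L(t) = e^{2t^2} + 3
t0 = 0.5

w_manual, dw_manual = forward_accumulation(t0)
w_jax, dw_jax = jax.jvp(L, (t0,), (1.0,))   # tangent 1.0 seeds dt/dt
print(dw_manual, dw_jax)
```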



# Reverse Accumulation:


In reverse accumulation, a single output of the target function is fixed, and its derivative with respect to the independent variables is recursively expanded by the chain rule. Here we fix the output $w$, and again we need to expand $\partial w / \partial t$. This time, the expansion substitutes, step by step, the derivative of $w$ with respect to each intermediate variable, starting next to the input and moving toward the output $w$; in other words, the expansion walks the path from $t$ to $w$. Once the expansion is complete, it is unwound by multiplying the values of the derivatives through the chain rule. This means the products are evaluated from $w$ to $t$, and the accumulated derivatives of the intermediate functions are repeatedly multiplied along the way. Since derivatives are propagated backward from the output $w$ to the input $t$, this mode is called backward propagation or reverse accumulation. Let's recursively expand $\partial w / \partial t$ with reverse accumulation:

$$\frac{\partial w}{\partial t} = \frac{\partial w}{\partial x} \frac{\partial x}{\partial t}$$

$$\frac{\partial w}{\partial t} = (\frac{\partial w}{\partial y} \frac{\partial y}{\partial x}) \frac{\partial x}{\partial t}$$

$$\frac{\partial w}{\partial t} = ((\frac{\partial w}{\partial z} \frac{\partial z}{\partial y}) \frac{\partial y}{\partial x}) \frac{\partial x}{\partial t}$$

Here again we write out each intermediate derivative in the chain rule step by step, but this time we follow the order $t \rightarrow x \rightarrow y \rightarrow z$. Since the recursive expansion proceeds from input to output, the parentheses that determine the evaluation priority are placed so that the derivatives are multiplied in the reverse order, that is $z \rightarrow y \rightarrow x \rightarrow t$.
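The same function can be differentiated in reverse order with a sketch like the one below: a forward pass stores every intermediate value, and a backward pass seeds $\partial w / \partial w = 1$ and multiplies the local derivatives in the order $z \rightarrow y \rightarrow x \rightarrow t$. The helper name `reverse_accumulation` is hypothetical and this is not how JAX is implemented internally; the result is checked against `jax.grad`, which uses reverse mode.

```python
import jax
import jax.numpy as jnp

def reverse_accumulation(t):
    # Hypothetical helper. Forward pass: evaluate the primitives and keep every intermediate value
    x = t ** 2
    y = 2.0 * x
    z = jnp.exp(y)
    w = z + 3.0
    # Backward pass: seed dw/dw = 1 and multiply local derivatives starting from the output side
    # (innermost parenthesis (dw/dz * dz/dy) first, i.e. z -> y -> x -> t)
    w_bar = 1.0
    z_bar = w_bar * 1.0            # dw/dz
    y_bar = z_bar * jnp.exp(y)     # (dw/dz) * dz/dy
    x_bar = y_bar * 2.0            # ((dw/dz) * dz/dy) * dy/dx
    t_bar = x_bar * 2.0 * t        # (((dw/dz) * dz/dy) * dy/dx) * dx/dt
    return w, t_bar

L = lambda t: jnp.exp(2.0 * t ** 2) + 3.0   # w = L(t) = e^{2t^2} + 3
t0 = 0.5

w_manual, dw_manual = reverse_accumulation(t0)
dw_jax = jax.grad(L)(t0)
print(dw_manual, dw_jax)
```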

# Automatic Differentiation for Neural Networks:

As you have probably seen in many deep learning courses and books, backpropagation is a special case of reverse-mode autodiff, and it is used to differentiate neural networks. Reverse mode is preferred over forward mode for efficiency. The choice of accumulation mode determines the order, and hence the priority, in which derivatives are multiplied through the chain rule. We opt for reverse mode here because multiplying the jacobians from the end of the composed function (a neural network in our case) toward its beginning lets us work with ***vector-jacobian products (VJP)***.


![](https://github.com/GoktugGuvercin/JAX-Tutorials/blob/main/Automatic%20Differentiation%20in%20JAX/Images/reverse%20accumulation.png?raw=true)

On the other hand, if we try to use forward accumulation, we face expensive matrix-matrix products throughout the chain rule. According to the recursive expansion of the derivative chain, forward accumulation starts multiplying from the beginning (the input side, i.e. the jacobian matrix in the rightmost position of the chain illustrated above) and proceeds to the end (the output side, i.e. the gradient vector in the leftmost position of the chain), which is the opposite of reverse accumulation. Choosing reverse mode turns these matrix-matrix products into vector-matrix products, which yields a large computational saving: multiplying two $n \times n$ matrices costs on the order of $n^3$ scalar multiplications, whereas multiplying a vector by an $n \times n$ matrix costs on the order of $n^2$.
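The difference in cost can be made concrete with a small sketch. It assumes a hypothetical three-layer chain with random square jacobians of width `n` (the shapes and the helper `matmul_cost` are illustrative assumptions, not taken from a real network), evaluates the chain in both association orders, and counts scalar multiplications with the textbook $a \cdot b \cdot c$ cost of multiplying an $a \times b$ matrix by a $b \times c$ matrix. Both orders give the same result; only the amount of work differs.

```python
import jax.numpy as jnp
from jax import random

def matmul_cost(a_shape, b_shape):
    # Hypothetical helper: scalar multiplications for an (a x b) @ (b x c) product
    return a_shape[0] * a_shape[1] * b_shape[1]

n = 64  # hypothetical width of every layer
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
grad_L = random.normal(k1, (1, n))                 # gradient of the scalar loss (row vector)
J3 = random.normal(k2, (n, n)) / jnp.sqrt(n)       # jacobian of the last layer
J2 = random.normal(k3, (n, n)) / jnp.sqrt(n)       # jacobian of the middle layer
J1 = random.normal(k4, (n, n)) / jnp.sqrt(n)       # jacobian of the first layer

# Reverse accumulation: start from the gradient vector -> only vector-matrix products
v = grad_L @ J3
v = v @ J2
reverse_result = v @ J1
reverse_cost = 3 * matmul_cost((1, n), (n, n))

# Forward accumulation without a seed vector: start from the input side -> matrix-matrix products
M = J2 @ J1
M = J3 @ M
forward_result = grad_L @ M
forward_cost = 2 * matmul_cost((n, n), (n, n)) + matmul_cost((1, n), (n, n))

print(reverse_cost, forward_cost)  # 12288 528384
```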

* *At this point, you may ask whether forward accumulation can be made more efficient and applicable to neural networks:*

In fact, by placing a vector at the beginning (the right end) of this recursive multiplication chain, we can avoid the cost of matrix-matrix products, and forward accumulation can proceed with ***jacobian-vector products (JVP)***. However, inserting a new multiplier into the chain rule changes the quantity being computed. To keep the result meaningful, the inserted vector is chosen to be a one-hot vector, as illustrated below: multiplying by it does not distort any values, it simply selects one column of the full chain of jacobians.

![](https://github.com/GoktugGuvercin/JAX-Tutorials/blob/main/Automatic%20Differentiation%20in%20JAX/Images/forward%20accumulation%20with%20one-hot%20encoding.png?raw=true)

One-hot vectors are the columns of the identity matrix; choosing one of them and placing it at the beginning of the forward accumulation chain makes the computation more efficient than before. However, at the end of this multiplication we obtain a scalar, not a vector: it is only one component of the network's full gradient vector, determined by the position of the $1$ in the one-hot vector. Therefore, we need to apply forward accumulation multiple times to differentiate the network, each time with a different column of the identity matrix, and each run yields a different component of the gradient. As a result, even if forward accumulation becomes more efficient in terms of multiplication complexity, it requires the differentiation to be repeated once per input dimension; that is why reverse mode is preferred for the backpropagation algorithm.
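The repeated forward passes can be sketched with `jax.jvp`. For a hypothetical scalar-valued function of $n$ inputs (the function `f` below is an illustrative assumption, not a real network), pushing each column of the identity matrix forward yields one partial derivative per call, so $n$ calls are needed to assemble the gradient, while a single reverse-mode call with `jax.grad` returns all of it at once.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued "network": R^n -> R
    return jnp.sum(jnp.tanh(x) ** 2) + jnp.prod(jnp.cos(x))

x = jnp.array([0.3, -1.2, 0.7, 2.0])
n = x.shape[0]

# Forward mode: one JVP per one-hot tangent e_i, each giving the scalar df/dx_i
partials = []
for e_i in jnp.eye(n):
    _, df_dxi = jax.jvp(f, (x,), (e_i,))
    partials.append(df_dxi)
grad_forward = jnp.stack(partials)

# Reverse mode: a single call returns the whole gradient
grad_reverse = jax.grad(f)(x)

print(grad_forward)
print(grad_reverse)
```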

* *Another interesting question is when forward accumulation becomes the more efficient choice:*

In fact, which accumulation mode is more convenient depends on the structure of your composed function. Since neural networks consume many inputs at a time and produce one scalar output (the loss) for them, like $f: R^n \rightarrow R$, the derivative of the last layer is a gradient vector, and with reverse accumulation we start the multiplication of derivatives as a vector-jacobian product. However, for a function like $f: R \rightarrow R^n$, the derivative of the first layer, not the last one, is a vector; hence, it is better to use forward accumulation so that the multiplication of derivatives is carried out as jacobian-vector products (JVP), as illustrated below.
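For the opposite shape, $f: R \rightarrow R^n$, a single JVP with the scalar tangent $1$ already yields the full derivative vector, while reverse mode needs one VJP (one one-hot cotangent) per output component. A minimal sketch with a hypothetical function `f`:

```python
import jax
import jax.numpy as jnp

def f(t):
    # Hypothetical function R -> R^3
    return jnp.stack([jnp.sin(t), t ** 2, jnp.exp(-t)])

t0 = 0.8

# Forward mode: one JVP with tangent 1.0 gives df/dt for all three outputs
_, df_dt = jax.jvp(f, (t0,), (1.0,))

# Reverse mode: one VJP per output component, using one-hot cotangents
y, vjp_fn = jax.vjp(f, t0)
rows = [vjp_fn(e_i)[0] for e_i in jnp.eye(y.shape[0])]
df_dt_reverse = jnp.stack(rows)

print(df_dt)
print(df_dt_reverse)
```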

![](https://github.com/GoktugGuvercin/JAX-Tutorials/blob/main/Automatic%20Differentiation%20in%20JAX/Images/forward%20accumulation.png?raw=true)

TensorFlow and PyTorch both rely on reverse accumulation for training, since neural networks (together with their loss) are scalar-valued functions with many inputs, a structure that matches the efficiency of reverse accumulation. On the other hand, as depicted above, the general pattern of the forward accumulation chain is not suitable for neural networks: the network output is a scalar rather than a vector, and since networks consume multi-dimensional input, differentiating them with respect to one scalar input at a time requires repeating forward accumulation for each input component. TensorFlow additionally supports a [forward accumulator](https://www.tensorflow.org/api_docs/python/tf/autodiff/ForwardAccumulator) and makes a nice comparison between the two modes. You can take a brief look at it.

# Pushforward and Pullback:

As you have noticed, from the calculus perspective, differentiating a neural network amounts to multiplying a gradient vector by a series of jacobians. Its geometric interpretation, however, offers another point of view: in differential geometry, the jacobian-vector and vector-jacobian products at the center of the gradient propagation process correspond to the pushforward and pullback operations:


**Tangent Vectors and Tangent Space:** 

As you remember from the previous notebook, the partial derivative of a multivariate function like $f: R^2 \rightarrow R$ is computed with respect to each input variable individually, while the other input variables are treated as constants. Differentiating a multivariate function in this way measures its rate of change at a point $x$ in $R^2$ along the directions of the canonical basis vectors of $R^2$. That is why a partial derivative is a special case of the directional derivative. What if we want to generalize this approach and differentiate $f$ along an arbitrary vector $v$, such as $(1, 2)$, instead of the canonical basis? Then $v = (1, 2)$ plays the role of a tangent vector attached to the point where $f$ is differentiated, and the set of all such tangent vectors at $x$ defines the tangent space, denoted $T_x(R^2)$.

<div style="width:100%;text-align: center;"> 
<img align=middle src="https://upload.wikimedia.org/wikipedia/commons/e/e7/Tangentialvektor.svg" width="400" height="250">
</div>
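The directional derivative along an arbitrary tangent vector is exactly what `jax.jvp` evaluates. The sketch below uses a hypothetical function $f: R^2 \rightarrow R$ and the vector $v = (1, 2)$ from the text, and checks that the result equals $\nabla f(x) \cdot v$; choosing $v$ as a canonical basis vector recovers an ordinary partial derivative.

```python
import jax
import jax.numpy as jnp

def f(p):
    # Hypothetical scalar field on R^2
    x, y = p
    return x ** 2 * y + jnp.sin(y)

point = jnp.array([1.0, 0.5])
v = jnp.array([1.0, 2.0])  # tangent vector in T_x(R^2)

_, directional = jax.jvp(f, (point,), (v,))     # directional derivative D_v f(point)
via_gradient = jnp.dot(jax.grad(f)(point), v)   # grad f(point) . v

# A canonical basis vector gives back the partial derivative df/dx = 2xy
_, partial_x = jax.jvp(f, (point,), (jnp.array([1.0, 0.0]),))

print(directional, via_gradient, partial_x)
```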

**Pushforward:**

Vector-valued functions map data between Euclidean spaces: they can map an item or data point in an $N$-dimensional space to an $M$-dimensional one, i.e. they perform a transformation. This idea lays the foundation of many machine learning and deep learning methods such as principal component analysis, fully-connected layers, and autoencoders. The remarkable point here is that the total derivative of such a function directly links the tangent spaces of the two Euclidean spaces. To illustrate, assume that there is a smooth vector-valued function $g: R^n \rightarrow R^m$. The total derivative of $g$ at a point $x$, known as the jacobian matrix, defines a linear map from the tangent space of $R^n$ at $x$ to the tangent space of $R^m$. In other words, the jacobian matrix behaves like a one-directional gate from the tangent space of the domain to the tangent space of the codomain. Some resources loosely call the tangent space of the codomain ($R^m$) the **cotangent space**, and this tutorial follows that naming; strictly speaking, the cotangent space at a point is the dual of the tangent space at that point, i.e. the space of linear functionals on its tangent vectors, so tangent and cotangent spaces are counterparts of each other. The jacobian matrix is used to **push** tangent vectors of $T_x(R^n)$ **forward** to tangent vectors of $T_x(R^m)$: taking a tangent vector from $T_x(R^n)$ and multiplying it by the jacobian of $g$ gives its counterpart tangent vector, and this is exactly what pushing forward means. Let's examine the illustration of the pushforward given below:


<div style="width:100%;text-align: center;"> 
<img align=middle src="https://upload.wikimedia.org/wikipedia/commons/3/37/Pushforward.svg" width="500" height="300">
</div>

In this illustration, $M$ and $N$ are differentiable manifolds. Since general manifolds belong to the context of differential forms and manifold learning, we can treat them as the Euclidean spaces $R^n$ and $R^m$ that we are all accustomed to, which makes the picture easier to understand. In this case, the map $\phi$ corresponds to $g: R^n \rightarrow R^m$, and the jacobian matrix $J_g(x)$ obtained by differentiating $g$ at a point $x$ defines the **pushforward function** $dg(x, v): T_x(R^n) \rightarrow T_x(R^m)$, also denoted $g_*$. The pushforward applies a *jacobian-vector product*, as in forward accumulation, to vectors of the tangent space in order to push them forward to the cotangent space (in the naming used above); that is why the JVP is generally described as a pushforward.

$$v \in T_x(R^n)$$

$$w \in T_x(R^m)$$

$$g_*(x, v) = J_g(x) \cdot v = w$$

**Pullback:**

At this point, you may think that the pullback is simply the reverse of the pushforward, but it is not; the two concepts are not exact opposites. The pushforward works with tangent vectors, whereas the pullback works with functions (and, more generally, forms) defined on the target space. In the simplest terms, the pullback moves a function defined on the codomain of $g_*$ to the domain of $g_*$ by composition. If there is a function $h: T_x(R^m) \rightarrow R$, composing it with $g_*$ creates a new function $h \circ g_*$, that is $v \mapsto h(g_*(x, v))$, whose domain is transferred from $T_x(R^m)$ to $T_x(R^n)$; as a result, $h$ is pulled back from the cotangent side to the tangent side.

<div style="width:100%;text-align: center;"> 
<img align=middle src="https://github.com/GoktugGuvercin/JAX-Tutorials/blob/main/Automatic%20Differentiation%20in%20JAX/Images/pushforward%20and%20pullback.png?raw=true" width="600" height="400">
</div>
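Both operations can be checked numerically. The sketch below assumes a hypothetical smooth map $g: R^3 \rightarrow R^2$ and a linear function $h(w) = u \cdot w$ on the output side. `jax.jvp` pushes a tangent vector $v$ forward to $J_g(x) v$, and pulling $h$ back by composing it with the pushforward gives the function $v \mapsto u \cdot (J_g(x) v) = (u^T J_g(x)) \cdot v$. The covector $u^T J_g(x)$ that represents this pulled-back function is exactly what `jax.vjp` returns, which is why the VJP is called a pullback.

```python
import jax
import jax.numpy as jnp

def g(p):
    # Hypothetical smooth map R^3 -> R^2
    return jnp.stack([p[0] * p[1], jnp.sin(p[2]) + p[0] ** 2])

x = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, -2.0])   # tangent vector at x
u = jnp.array([3.0, -1.0])        # coefficients of the linear function h(w) = u . w

J = jax.jacfwd(g)(x)              # 2x3 jacobian of g at x

# Pushforward: the tangent vector v is mapped to J @ v
_, pushed = jax.jvp(g, (x,), (v,))

# Pullback: h composed with the pushforward is a (linear) function of tangent vectors at x
h = lambda w: jnp.dot(u, w)
pulled_h = lambda tangent: h(jax.jvp(g, (x,), (tangent,))[1])

# jax.vjp returns the covector u^T J that represents the pulled-back function
_, vjp_fn = jax.vjp(g, x)
(u_J,) = vjp_fn(u)

print(pushed, u_J)
```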

The description of the pullback may remind you of backpropagation, since backpropagation is essentially a sequence of pullback operations. When the last part of the network (the loss function that takes $m$ output scores and produces a scalar loss value) is differentiated, you obtain a gradient vector in $R^m$. During backpropagation, multiplying this gradient vector by the jacobian $J_f(W) \in R^{m \times n}$ of the layer preceding the loss function maps it to a vector in $R^n$ in the language of matrix calculus. From the point of view of differential geometry, however, this change of dimension is a pullback. To see this clearly, we replace the gradient vector and the jacobian matrix by their corresponding functions: instead of resizing a vector, we change (pull back) the domain of the loss function to the domain of the functions representing the intermediate layers of the network. Let's go over it step by step.

* Loss function: $L(S): R^m \rightarrow R$
* Last affine transformation layer in the network: $f(W): R^n \rightarrow R^m$
* Derivative of the loss function: $\nabla_S L(S) \in R^m$
* Derivative of the transformation layer: $J_f(W) \in R^{m \times n}$

During backpropagation, these two derivatives are multiplied by each other, which is exactly the vector-jacobian product of reverse accumulation:

* Vector-jacobian product: $\nabla_S L(S) \cdot J_f(W) = \nabla_W L(f(W)) \in R^n$

If this multiplication is viewed as a composition of the corresponding functions, what backpropagation actually does is pull the loss function back from $R^m$ to $R^n$. First replace the derivatives in the product by their functions, and then compose the functions instead of multiplying the derivatives, keeping the same left-to-right order:

* Pullback: $L \circ f: R^n \rightarrow R$, that is, $W \mapsto L(f(W))$
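The steps above can be sketched in JAX. The sketch assumes a hypothetical affine layer $f(W) = A W + c$ with a fixed matrix $A \in R^{m \times n}$ and offset $c$, and a squared-error loss $L(S) = \frac{1}{2}\lVert S - \text{target}\rVert^2$; all of these are illustrative choices. Multiplying the loss gradient at the layer output by the layer jacobian with `jax.vjp` gives the same vector as differentiating the composed function $W \mapsto L(f(W))$ directly.

```python
import jax
import jax.numpy as jnp
from jax import random

m, n = 3, 5  # hypothetical output and input sizes of the layer
kA, kc, kW, kt = random.split(random.PRNGKey(1), 4)
A = random.normal(kA, (m, n))
c = random.normal(kc, (m,))
target = random.normal(kt, (m,))
W = random.normal(kW, (n,))

f = lambda W: A @ W + c                            # affine layer R^n -> R^m
L = lambda S: 0.5 * jnp.sum((S - target) ** 2)     # scalar loss R^m -> R

# Step 1: gradient of the loss at the layer output, a vector in R^m
S = f(W)
grad_S = jax.grad(L)(S)

# Step 2: pull it back through the layer with a vector-jacobian product
_, vjp_fn = jax.vjp(f, W)
(grad_W,) = vjp_fn(grad_S)                         # grad_S @ J_f(W), a vector in R^n

# The gradient of the composed function L(f(W)) gives the same vector
grad_W_direct = jax.grad(lambda W: L(f(W)))(W)

print(grad_W)
print(grad_W_direct)
```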

This is why the vector-jacobian product is commonly called a pullback. In the context of autodiff, however, pushforward and pullback are mostly technical jargon, and there is no need to go deeper into their details. The main reason I explained these two operations is to give you a small intuition for how they relate to JVP and VJP and to make you more familiar with them, since this helps in understanding the JAX documentation and the arguments passed to [JVP](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp) and [VJP](https://jax.readthedocs.io/en/latest/_autosummary/jax.vjp.html#jax.vjp).

# Jacobian-Vector Product (JVP) in JAX:

Forward-mode autodiff evaluates the accumulated derivatives of a composite function through the chain rule in a right-associative manner; hence, for a function like $f: R \rightarrow R^m$, the derivative vector at the beginning of the chain is repeatedly multiplied by the subsequent jacobians. Each of these multiplications is a separate jacobian-vector product, which is modelled by the function $jax.jvp()$ in JAX. This function takes three input arguments:

* The function to be differentiated
* Input arguments of that function (primals)
* Tangent vectors of input domain (tangents)

With these arguments, $jax.jvp()$ evaluates the derivative of the function at the given input values and multiplies it by the tangent vectors, mapping them to their counterparts on the output side (the cotangent vectors, in the naming used above). Here the derivative is the jacobian, and the tangent being multiplied can be considered the gradient vector accumulated so far. In this way, one step of the chain rule is carried out with the right-associative character of forward accumulation. The function returns a pair: the function output evaluated at the primals and the resulting vector.
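The step-by-step character of forward mode can be sketched by pushing a tangent through a hypothetical three-stage composition one stage at a time (the functions `f1`, `f2`, `f3` are illustrative assumptions). Each `jax.jvp` call performs one jacobian-vector product, and its output vector becomes the input tangent of the next stage; the end result agrees with a single `jax.jvp` call on the whole composition.

```python
import jax
import jax.numpy as jnp

f1 = lambda t: jnp.stack([t, t ** 2, jnp.sin(t)])     # R -> R^3
f2 = lambda a: 2.0 * jnp.tanh(a)                       # R^3 -> R^3
f3 = lambda b: jnp.stack([b[0] + b[1], b[1] * b[2]])   # R^3 -> R^2

composite = lambda t: f3(f2(f1(t)))

t0 = 0.3
tangent = 1.0

# Push the tangent forward through each stage: one JVP per jacobian in the chain
a, da = jax.jvp(f1, (t0,), (tangent,))
b, db = jax.jvp(f2, (a,), (da,))
out, dout = jax.jvp(f3, (b,), (db,))

# A single JVP on the whole composition gives the same output and tangent
out_ref, dout_ref = jax.jvp(composite, (t0,), (tangent,))

print(dout, dout_ref)
```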

However, the primals and tangent vectors passed to $jax.jvp()$ are not treated individually. Instead, they are combined so that the result is a single output vector. Let's assume that we have a function $f(W, b) = A \cdot W^T + b$ for $W \in R^5$ and $b \in R$. This function does not live on two separate domains $R^5$ and $R$; instead, the domains of all input variables are combined into a single function domain $R^6$. Mathematically, this is what $jax.jvp()$ computes: it behaves as if the pytrees of input arguments and of the corresponding tangents were concatenated into one primal and one tangent, so that the function domain and a tangent vector from its tangent space are treated as a whole, and a single jacobian-vector product is carried out over that combined primal-tangent pair. Equivalently, $J v = \sum_i J_i v_i$, where $J_i$ is the jacobian with respect to the $i$-th argument and $v_i$ is its tangent. You can look at the [discussion](https://github.com/google/jax/discussions/10227) that I opened and examine the source code in it for a better understanding.


You can also find the list of all automatic differentiation functions in JAX [here](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation).

```python
key = random.PRNGKey(0)
W_key, b_key, input_key, v1_key, v2_key = random.split(key, 5)

b = random.normal(b_key, (1,))  # bias in R
W = random.normal(W_key, (5,))  # weights in R^5
x = random.normal(input_key, (1, 5))  # 1 sample, 5 features per sample

tan_vec1 = random.normal(v1_key, W.shape)  # tangent vector from the tangent space of R^5
tan_vec2 = random.normal(v2_key, b.shape)  # tangent vector from the tangent space of R

single_neuron = lambda W, b: x @ W + b  # model1: W and b as separate arguments
single_neuron2 = lambda params: x @ params[:-1] + params[-1]  # model2: W and b as one vector in R^6

# * These two models are the same and are fed the same primals (parameters).
# * However, model2 takes the concatenated and flattened version of those primals.
# * We manually compute the jacobian of model2 and multiply it by the concatenated
#   and flattened version of the tangents.
# * If this manual JVP equals what jax.jvp() returns for model1, the concept behind
#   jax.jvp() explained above is confirmed.

y, cotan_vecs = jax.jvp(single_neuron, (W, b), (tan_vec1, tan_vec2))  # cotan_vecs = J @ [tan_vec1; tan_vec2]

con_params = jnp.concatenate((W, b))  # concatenated primal in R^6
con_tangents = jnp.concatenate((tan_vec1, tan_vec2))  # concatenated tangent in R^6

jacobian_fun = jax.jacrev(single_neuron2)
jacobian = jacobian_fun(con_params)  # shape (1, 6)

print("Manual JVP of model 2: ", jacobian @ con_tangents)
print("What jax.jvp() returns for model 1: ", cotan_vecs)
```

```
Manual JVP of model 2:  [-0.83545744]
What jax.jvp() returns for model 1:  [-0.8354575]
```


# Vector-Jacobian Product (VJP) in JAX:

The vector-jacobian product, the core of reverse-mode autodiff, carries out the repeated multiplications of accumulated derivatives in a left-associative manner, and $jax.vjp()$ is designed to compute it. Its interface differs a little from $jax.jvp()$: it does not directly multiply cotangent vectors by the derivative of the function. Instead, it returns the function output together with a reusable function that can be fed any cotangent vector and maps it to its tangent counterpart. Let's look at a simple example to better understand how $jax.vjp()$ is used:

```python
# Let's use the model1 function introduced above.
# jvp directly performs the computation and maps tangents to cotangents.
# vjp returns a reusable function that performs the reverse mapping.
# Let's pull the cotangents computed by jvp back to get their tangent counterparts.

y, vjp_fn = jax.vjp(single_neuron, W, b)  # returns y = f(W, b) and the function in charge of the VJP
tan_vecs = vjp_fn(cotan_vecs)  # tan_vecs = transpose(cotan_vecs) @ jacobian, split per argument

# Let's reproduce vjp_fn explicitly.
# To accomplish this, we have to build the jacobian matrix.

jacobian_fun = jax.jacrev(single_neuron, argnums=(0, 1))
jacobian1, jacobian2 = jacobian_fun(W, b)  # jacobian w.r.t. W (1x5) and w.r.t. b (1x1)

# Remember the jvp example above!
# The function does not work on two different input domains (R^5 and R); it actually works on R^6.
# So, the jacobians w.r.t. both inputs are combined.
full_jacobian = jnp.concatenate((jacobian1, jacobian2), axis=1)  # 1x5 and 1x1 -> 1x6

print("Tangent vectors obtained by jax.vjp pulling back cotangent vectors: ", tan_vecs)
print("\nTangent vectors obtained by explicit vector-jacobian product: ", cotan_vecs @ full_jacobian)
```

```
Tangent vectors obtained by jax.vjp pulling back cotangent vectors:  (DeviceArray([-0.05531598,  1.8866304 ,  1.1591966 ,  0.15120156,
             -0.08268344], dtype=float32), DeviceArray([-0.8354575], dtype=float32))

Tangent vectors obtained by explicit vector-jacobian product:  [-0.05531598  1.8866304   1.1591966   0.15120156 -0.08268344 -0.8354575 ]
```


You have probably noticed that while $vjp$ provides us with a function that pulls cotangent vectors back, $jvp$ directly performs the computation (the pushforward) and returns the result. These two functions follow different design approaches. Let's examine them in detail:

If we have multiple tangent vectors to push forward at the same point, we need to call $jvp()$ many times, and on each call our original composite function, such as a neural network $Loss = f(x, y) = L_n \circ \cdots \circ L_2 \circ L_1(x)$, is evaluated again. On the other hand, calling $vjp()$ only once is enough for multiple cotangent vectors to be mapped back to the tangent space, because instead of doing this mapping itself, it delegates the task to the function it returns, which is why I called it a reusable function above. In this way, $vjp()$ does not need to be invoked many times, and the intermediate and final results of the composite function are computed only once. The trade-off for this advantage is that all linearization points must be stored, which costs memory. Two conceptual questions help clarify the whole process at this point.
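One way to observe this difference is a Python-side counter. JAX runs the Python body of a function only while tracing it, so the counter below (a hypothetical helper, used purely for illustration) shows how many times the function is traced. Three `jax.jvp` calls trace it three times, whereas one `jax.vjp` call traces it once, and the returned `vjp_fn` can then be applied to several cotangents without re-running the function.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}  # hypothetical counter, for illustration only

def f(x):
    trace_count["n"] += 1            # Python side effect: runs once per trace of f
    return jnp.sin(x) * jnp.sum(x ** 2)

x = jnp.array([0.1, 0.2, 0.3])

# Pushing three tangents forward requires three jvp calls (three traces of f)
trace_count["n"] = 0
for v in jnp.eye(3):
    jax.jvp(f, (x,), (v,))
jvp_traces = trace_count["n"]

# One vjp call traces f once; vjp_fn is then reused for three cotangents
trace_count["n"] = 0
y, vjp_fn = jax.vjp(f, x)
pulled = [vjp_fn(ct)[0] for ct in jnp.eye(3)]
vjp_traces = trace_count["n"]

print(jvp_traces, vjp_traces)  # 3 1
```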

***1. What are the linearization points?***

Linearization points are the points at which functions are approximated linearly by a first-order Taylor expansion. To find the derivative of a function at a point $x$, we draw the line tangent to the function at $x$ and compute its slope. This is what a first-order Taylor expansion does: it approximates the target function linearly around $x$, as in $f(x + \Delta x) \approx f(x) + f'(x)\,\Delta x$, and the slope of that line gives the change in the function's output per unit change in its input.
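`jax.linearize` makes the first-order Taylor view concrete: it returns the function value at the linearization point together with a linear function of the tangent, so that $f(x_0 + \epsilon v) \approx f(x_0) + \epsilon\, J_f(x_0) v$ for small $\epsilon$. The sketch below uses a hypothetical function `f` and also checks that the returned linear function agrees with `jax.jvp` at the same point.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical nonlinear function R^2 -> R^2
    return jnp.stack([jnp.sin(x[0]) * x[1], jnp.exp(x[0] - x[1])])

x0 = jnp.array([0.4, 1.1])          # linearization point
v = jnp.array([1.0, -0.5])          # tangent direction
eps = 1e-3

y0, f_jvp = jax.linearize(f, x0)    # f(x0) and the linear map v -> J_f(x0) @ v

# First-order Taylor expansion vs. the true function value
taylor = y0 + f_jvp(eps * v)
exact = f(x0 + eps * v)

# The linear map agrees with jax.jvp at the same point
_, jvp_out = jax.jvp(f, (x0,), (v,))

print(jnp.max(jnp.abs(taylor - exact)))
```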

***2. How is a function linearized?***

To linearize a function (approximate it linearly), JAX provides the function [$jax.linearize()$](https://jax.readthedocs.io/en/latest/_autosummary/jax.linearize.html), which takes the function to be differentiated along with its primals and passes them to the partial evaluation machinery. To do this, $jax.linearize()$ calls $ad.linearize()$, which is part of $jax.interpreters$. This process first performs a couple of preliminary operations, such as flattening the function so that it can be fed jaxpr-typed arguments, and creating partial input values that combine the primals with abstract values for their corresponding tangents. Then, $ad.linearize()$ passes them to the $trace\_to\_jaxpr()$ function, where the partial evaluation is actually carried out. This function evaluates the flattened composite function, computes its known (partial) outputs, and creates a [$jaxpr$](https://jax.readthedocs.io/en/latest/jaxpr.html) object that represents the computation still to be done on the tangents. In this way, the linearization is completed. For a better understanding of this process, you can look at the source code of [$ad.linearize()$](https://github.com/google/jax/blob/850e8a756a78388605745d241a7f25daa371a23b/jax/interpreters/ad.py#L81) and [$partial\_eval.trace\_to\_jaxpr()$](https://github.com/google/jax/blob/main/jax/interpreters/partial_eval.py#L576).
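To peek at what the partial evaluation produces, one can trace the linear function returned by `jax.linearize` with `jax.make_jaxpr`. In this sketch (hypothetical function `f`), the nonlinear values such as $\sin(x_0)$ and $\cos(x_0)$ are computed at the linearization point and stored as constants, so the printed jaxpr of the linear map is expected to consist of linear operations like multiplications by those constants and additions; the exact printout depends on the JAX version.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical nonlinear function
    return jnp.sin(x) * x

x0 = jnp.array([0.5, 1.5])
y0, f_jvp = jax.linearize(f, x0)

# Trace the linear map v -> J_f(x0) @ v; values depending only on x0 are stored as constants
linear_jaxpr = jax.make_jaxpr(f_jvp)(jnp.ones_like(x0))
jaxpr_text = str(linear_jaxpr)
print(jaxpr_text)

# For this elementwise function, J_f(x0) @ v = (cos(x0) * x0 + sin(x0)) * v
v = jnp.array([1.0, -2.0])
manual = (jnp.cos(x0) * x0 + jnp.sin(x0)) * v
print(f_jvp(v), manual)
```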

In the simplest terms, the values computed by $trace\_to\_jaxpr()$ are the intermediate results and the final result of $Loss = f(x, y) = L_n \circ \cdots \circ L_2 \circ L_1(x)$. Since these values are recorded during linearization, they can easily be reused to compute the derivative of each $L_i$ and, finally, to define the reusable function $vjp\_fn$ returned by $jax.vjp()$.

Linearization allows all of these results and derivatives to be precomputed and then used efficiently through a function closure that defines the reusable function $vjp\_fn()$. If we have several cotangent vectors, we can pass them directly to $vjp\_fn()$, and these values are not recomputed each time, unlike with $jax.jvp()$. This reduces computational cost, but storing those values for the closure increases memory consumption. Hence, we can say that $jax.jvp()$ carries out its work with less memory.

Another notable point is the relationship between $jax.jvp()$ and $jax.linearize()$. Linearization is in fact built on top of $jvp()$: in both cases, the composite function is differentiated through jacobian-vector products, so $jax.linearize()$ and $jax.jvp()$ compute the same thing. The difference is that linearization records all computations and derivatives, so it does not recompute them for every new tangent. At this point, you may be confused about the structure of $jax.vjp()$:

***3. If the linearization procedure is carried out with jacobian-vector products, how can $jax.vjp()$ utilize the linearization?***

In fact, $jax.vjp()$ performs all operations with jacobian-vector products, but then it transposes the jaxpr that represents these linear computations, which converts them into vector-jacobian products. You can picture it as follows:

* linearization = jvp $+$ partial evaluation machinery
* vjp = jvp $+$ partial evaluation machinery $+$ transpose
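The second line of this picture can be reproduced with public JAX functions: `jax.linearize` gives the linear JVP map produced by partial evaluation, and `jax.linear_transpose` transposes a linear function. Composing the two yields the same pullback as `jax.vjp`. This is a sketch of the idea with a hypothetical function `f`, not a description of exactly how `jax.vjp` is wired internally.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function R^3 -> R^2
    return jnp.stack([jnp.sum(jnp.sin(x)), x[0] * x[1] * x[2]])

x0 = jnp.array([0.2, -0.7, 1.3])
ct = jnp.array([1.0, 0.5])                       # cotangent on the output side

# linearization = jvp + partial evaluation
y0, f_jvp = jax.linearize(f, x0)

# vjp = linearization + transpose (x0 only supplies the shape/dtype of the tangent input)
f_vjp_manual = jax.linear_transpose(f_jvp, x0)
(manual,) = f_vjp_manual(ct)

# Reference: jax.vjp
_, vjp_fn = jax.vjp(f, x0)
(reference,) = vjp_fn(ct)

print(manual, reference)
```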

While studying the autodiff functions $jvp$ and $vjp$ in JAX, I was confused about some details. To figure out the main idea behind them, I opened a [discussion](https://github.com/google/jax/discussions/10271#discussioncomment-2568196) a couple of weeks ago. You can read the explanation and tips that JAX maintainer Matthew Johnson gave in that discussion. I am really grateful to him; he helped me fill in all the missing parts in my understanding. In addition, there is another good [discussion](https://github.com/google/jax/issues/526) about these topics, full of remarkable insights written by Matthew Johnson. You can also check it for a clearer understanding.

# References:

**Automatic Differentiation:**

* https://en.wikipedia.org/wiki/Automatic_differentiation
* https://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/slides/lec10.pdf
* https://github.com/tensorflow/tensorflow/issues/675
* https://stackoverflow.com/questions/43022091/where-in-tensorflow-gradients-is-the-sum-over-the-elements-of-y-made
* http://theoryandpractice.org/stats-ds-book/autodiff-tutorial.html
* https://matt-graham.github.io/slides/ad/index.html#/
* https://www.cs.ubc.ca/~fwood/CS340/lectures/AD1.pdf
* https://scholar.princeton.edu/sites/default/files/ast558_seminar_tutorial_on_automatic_differentiation-2.pdf
* http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/
* https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
* https://yaroslavvb.medium.com/backprop-and-systolic-arrays-24e925d2050

**Pushforward and Pullback:**

* https://en.wikipedia.org/wiki/Tangent_space
* https://en.wikipedia.org/wiki/Cotangent_space
* https://en.wikipedia.org/wiki/Pushforward_(differential)
* https://en.wikipedia.org/wiki/Pullback_(differential_geometry)
* https://www.usu.edu/math/fels/complete.pdf
* https://www.quora.com/Why-is-a-differential-referred-to-as-“push-forward”
* https://www.mathphysicsbook.com/mathematics/manifolds/mapping-manifolds/the-differential-and-pullback/
* https://math.stackexchange.com/questions/1189712/geometric-intuition-behind-pullback

**JVP and VJP in JAX:**

* https://jax.readthedocs.io/en/latest/autodidax.html#part-1-transformations-as-interpreters-standard-evaluation-jvp-and-vmap
* https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp
* https://jax.readthedocs.io/en/latest/_autosummary/jax.vjp.html#jax.vjp
* https://jax.readthedocs.io/en/latest/_autosummary/jax.linearize.html#jax.linearize
* https://jax.readthedocs.io/en/latest/jaxpr.html
* https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html

**Discussions:**

* https://github.com/google/jax/discussions/10227
* https://github.com/google/jax/discussions/10271#discussioncomment-2568196
* https://github.com/google/jax/issues/526