# PyTorch vs. Jax vs. TensorFlow
Computing the gradient with three machine-learning toolkits

In [1]:
import torch
import jax
import tensorflow as tf
import math

A surface in 3D:
$$
z = x^2 + xy + y^2
$$
Its gradient:
$$
\begin{array}{lcl}
\frac{\partial z}{\partial x} &=& 2x + y\\
\frac{\partial z}{\partial y} &=& x + 2y\\
\end{array}
$$
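The two partial derivatives can be written directly in Python, giving a reference to compare the toolkits against. This is a minimal sketch. The name `grad_f` is hypothetical and not part of any library.

```python
def grad_f(x, y):
    """Hand-derived gradient of z = x^2 + xy + y^2, returned as (dz/dx, dz/dy)."""
    dz_dx = 2 * x + y
    dz_dy = x + 2 * y
    return dz_dx, dz_dy

print(grad_f(3, 4))  # (10, 11)
```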

The same function in Python:

In [4]:
def f(x, y):
    return x**2 + x * y + y**2

In [5]:
f(3, 4)

37

Manual computation of the gradient at $(3, 4)$:
$$
\begin{array}{lcl}
\nabla f(3, 4) &=& (2 \times 3 + 4, 3 + 2 \times 4)\\
&=& (10, 11)
\end{array}
$$
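The hand computation can be checked numerically with central differences, $\frac{f(x+h,\,y) - f(x-h,\,y)}{2h}$ and likewise in $y$. Because $f$ is quadratic, this formula is exact up to floating-point rounding. This is a minimal sketch. The helper `numerical_grad` and the step `h = 1e-5` are illustrative choices, not taken from any library.

```python
def f(x, y):
    return x**2 + x * y + y**2

def numerical_grad(func, x, y, h=1e-5):
    """Central-difference approximation of the gradient of func at (x, y)."""
    d_dx = (func(x + h, y) - func(x - h, y)) / (2 * h)
    d_dy = (func(x, y + h) - func(x, y - h)) / (2 * h)
    return d_dx, d_dy

gx, gy = numerical_grad(f, 3.0, 4.0)
print(round(gx, 6), round(gy, 6))  # 10.0 11.0
```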

## Jax
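We compute the gradient with Jax, arguably the most intuitive of the three. `jax.grad` takes a function and returns a new function that evaluates its gradient. `argnums` selects which positional argument(s) to differentiate with respect to.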

In [6]:
# jax.grad only differentiates with respect to floating-point inputs
xj = 3.0
yj = 4.0

In [7]:
dzj_dxj = jax.grad(f, argnums=0)         # dz/dx
dzj_dyj = jax.grad(f, argnums=1)         # dz/dy
dzj_dxjyj = jax.grad(f, argnums=(0, 1))  # (dz/dx, dz/dy) as a tuple

In [8]:
dzj_dxj(xj, yj)

Array(10., dtype=float32, weak_type=True)

In [9]:
dzj_dyj(xj, yj)

Array(11., dtype=float32, weak_type=True)

In [7]:
dzj_dxjyj(xj, yj)

(Array(10., dtype=float32, weak_type=True),
 Array(11., dtype=float32, weak_type=True))
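Two related Jax transformations are handy here. `jax.value_and_grad` returns $f$'s value together with its gradient in one call, and `jax.vmap` vectorizes the gradient function over a batch of points. This is a sketch assuming a recent Jax release. The sample points are arbitrary.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x**2 + x * y + y**2

# Value and gradient in one pass; argnums=(0, 1) differentiates w.r.t. x and y.
value, (dz_dx, dz_dy) = jax.value_and_grad(f, argnums=(0, 1))(3.0, 4.0)
print(value, dz_dx, dz_dy)

# Gradient at several points at once: vmap maps grad_xy over axis 0 of xs and ys.
grad_xy = jax.grad(f, argnums=(0, 1))
xs = jnp.array([0.0, 1.0, 3.0])
ys = jnp.array([0.0, 2.0, 4.0])
gxs, gys = jax.vmap(grad_xy)(xs, ys)
print(gxs, gys)
```

At $(1, 2)$ the batched gradient is $(4, 5)$, and at $(3, 4)$ it is again $(10, 11)$, matching $(2x + y, x + 2y)$.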

## PyTorch

In [10]:
# requires_grad=True tells autograd to record the operations applied to these tensors
xt = torch.tensor(3.0, requires_grad=True)
yt = torch.tensor(4.0, requires_grad=True)

In [11]:
zt = f(xt, yt)
zt

tensor(37., grad_fn=<AddBackward0>)

In [12]:
# Backpropagate through the recorded graph; this fills xt.grad and yt.grad
zt.backward()

In [13]:
zt

tensor(37., grad_fn=<AddBackward0>)

In [14]:
zt.grad_fn

<AddBackward0 at 0x199ad0e50>

In [15]:
zt.grad_fn.next_functions

((<AddBackward0 at 0x199ad0c70>, 0), (<PowBackward0 at 0x199ad0c10>, 0))

In [16]:
zt.grad_fn.next_functions[0][0].next_functions

((<PowBackward0 at 0x199ad0c40>, 0), (<MulBackward0 at 0x199ad0dc0>, 0))

In [17]:
zt.grad_fn.next_functions[0][0].next_functions[0][0].next_functions

((<AccumulateGrad at 0x199ad0df0>, 0),)

In [16]:
zt.grad_fn.next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions

()
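Rather than indexing into `next_functions` by hand, the graph can be walked recursively. This sketch relies only on `grad_fn`, `next_functions`, and the node class names seen above. The helper `graph_lines` is hypothetical. Leaves that accumulate into `.grad` show up as `AccumulateGrad`, and inputs that do not require gradients appear as `None` and are skipped.

```python
import torch

def f(x, y):
    return x**2 + x * y + y**2

def graph_lines(fn, depth=0):
    """Return one indented line per node of the backward graph rooted at fn."""
    if fn is None:  # inputs that do not require grad appear as None
        return []
    lines = ["  " * depth + type(fn).__name__]
    for next_fn, _ in fn.next_functions:
        lines += graph_lines(next_fn, depth + 1)
    return lines

xt = torch.tensor(3.0, requires_grad=True)
yt = torch.tensor(4.0, requires_grad=True)
zt = f(xt, yt)
lines = graph_lines(zt.grad_fn)
print("\n".join(lines))
```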

In [18]:
xt.grad, yt.grad

(tensor(10.), tensor(11.))
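In PyTorch, `backward()` adds into `.grad` rather than overwriting it, so repeated passes accumulate. When only the gradient values are needed, `torch.autograd.grad` returns them directly without touching `.grad`. A minimal sketch:

```python
import torch

def f(x, y):
    return x**2 + x * y + y**2

xt = torch.tensor(3.0, requires_grad=True)
yt = torch.tensor(4.0, requires_grad=True)

# backward() adds into .grad: two passes give twice the gradient.
f(xt, yt).backward()
f(xt, yt).backward()
print(xt.grad, yt.grad)  # tensor(20.) tensor(22.)

# Reset before the next computation (optimizers do this via zero_grad()).
xt.grad = None
yt.grad = None

# torch.autograd.grad returns the gradients instead of storing them in .grad.
dz_dx, dz_dy = torch.autograd.grad(f(xt, yt), (xt, yt))
print(dz_dx, dz_dy)  # tensor(10.) tensor(11.)
print(xt.grad)  # None
```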

## TensorFlow
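Arguably the most convoluted of the three. The computation must run inside a `tf.GradientTape` context, and tensors created with `tf.constant` must be watched explicitly with `tape.watch`.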

In [19]:
xtf = tf.constant(3.0)
ytf = tf.constant(4.0)

In [20]:
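# persistent=True allows tape.gradient() to be called more than once;
# constants are not tracked unless explicitly watched
with tf.GradientTape(persistent=True) as tape:
    tape.watch(xtf)
    tape.watch(ytf)
    ztf = f(xtf, ytf)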

In [21]:
tape.gradient(ztf, xtf)

<tf.Tensor: shape=(), dtype=float32, numpy=10.0>

In [22]:
tape.gradient(ztf, ytf)

<tf.Tensor: shape=(), dtype=float32, numpy=11.0>
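The tape above needs `persistent=True` only because `tape.gradient` is called twice. A shorter variant is sketched below. Passing a list of sources to a single `gradient` call returns both partial derivatives at once. Using `tf.Variable` instead of `tf.constant` removes the need for `tape.watch`, since trainable variables are watched automatically. This assumes TensorFlow 2.x eager execution.

```python
import tensorflow as tf

def f(x, y):
    return x**2 + x * y + y**2

# Trainable tf.Variable objects are watched automatically, unlike tf.constant.
xv = tf.Variable(3.0)
yv = tf.Variable(4.0)

# A non-persistent tape is enough when gradient() is called only once.
with tf.GradientTape() as tape:
    zv = f(xv, yv)

# A list of sources yields a list of gradients, in the same order.
dz_dx, dz_dy = tape.gradient(zv, [xv, yv])
print(dz_dx.numpy(), dz_dy.numpy())  # 10.0 11.0
```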