Code based on the torch.func whirlwind tour: https://pytorch.org/docs/stable/func.whirlwind_tour.html

In [12]:
import torch
from torch.func import vjp, jvp, grad, vmap, hessian

# Stateful vs Functional API

In [7]:
W = torch.randn(1, 2, requires_grad=True)
x = torch.randn(2, requires_grad=True)

def f(W):
    return ((W ** 2)@x).sum()

In [8]:
W.grad = None  # reset gradient (optimizer.zero_grad)
l = f(W)
l.backward()
print("stateful API: ", W.grad.tolist())

stateful API:  [[0.16340939700603485, 0.005695031955838203]]


In [9]:
gw = grad(f)(W)
print("functional API: ", gw.tolist())

functional API:  [[0.16340939700603485, 0.005695031955838203]]
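The cell above resets `W.grad` before calling `.backward()` because the stateful API accumulates gradients, while the functional API has no such state. The sketch below shows that difference. It assumes a fresh `W` and `x`, seeded for reproducibility.

```python
import torch
from torch.func import grad

torch.manual_seed(0)
x = torch.randn(2)

def f(W):
    return ((W ** 2) @ x).sum()

W = torch.randn(1, 2, requires_grad=True)

# Stateful: every .backward() call adds its result into W.grad
f(W).backward()
f(W).backward()
stateful_grad = W.grad.clone()

# Functional: grad(f) returns a fresh tensor and does not write to W.grad
functional_grad = grad(f)(W)

expected = 2 * W.detach() * x  # closed form of df/dW
print(torch.allclose(stateful_grad, 2 * expected))  # True: accumulated twice
print(torch.allclose(functional_grad, expected))    # True
```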


# Backprop over gradient operations

We are going to see some toy examples of gradient computations on $y = \sum_i w_i^2 x_i$ (the function `foo` in the code).
Here are the function and its derivatives, so that you can check that the results are correct:

$$
\begin{aligned}
& y=w^2 \cdot x=w_1^2 x_1+w_2^2 x_2 \\
& \frac{\partial y}{\partial w}=2 w x=\left[\begin{array}{l}
2 w_1 x_1 \\
2 w_2 x_2
\end{array}\right] \\
& \frac{\partial^2 y}{\partial w_i \partial w_j}=\operatorname{diag}(2 x)=\left[\begin{array}{cc}
2 x_1 & 0 \\
0 & 2 x_2
\end{array}\right]
\end{aligned}
$$

Here $w^2$ and $w x$ are elementwise. The Hessian is diagonal, so multiplying it by the all-ones vector from either side (VJP or JVP) gives $2x$.
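You can also check the Hessian directly with `torch.func.hessian`, which is imported at the top but not otherwise used. The sketch below assumes a 1-D `w`, so the result is a `(2, 2)` matrix. With the notebook's `W` of shape `(1, 2)`, the result would have shape `(1, 2, 1, 2)`.

```python
import torch
from torch.func import grad, hessian, jacrev

torch.manual_seed(0)
w = torch.randn(2)
x = torch.randn(2)

def foo(w):
    return ((w ** 2) @ x).sum()

H = hessian(foo)(w)            # full second-derivative matrix, shape (2, 2)
H_rev = jacrev(grad(foo))(w)   # Jacobian of the gradient: the same matrix
expected = torch.diag(2 * x)   # diag(2x), as derived above
print(torch.allclose(H, expected), torch.allclose(H_rev, expected))  # True True
```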

In [10]:
from torch.func import vjp, jvp, grad

W = torch.randn(1, 2, requires_grad=True)
x = torch.randn(2, requires_grad=True)
ones = torch.tensor([1.0, 1.0]).reshape(1, 2)

foo = lambda W: ((W ** 2)@x).sum()
gw = grad(foo)(W)
print("dfoo/dW:\t", gw.tolist())
print("expect: \t", (2*W*x).tolist())
print()

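print("Vector-Jacobian Product")
# vjp returns (grad(foo)(W), vjp_fn); vjp_fn(v) computes v^T H, with H the Hessian of foo
_, vjp_fn = vjp(grad(foo), W)
print("ones @ ddfoo/dW:\t", vjp_fn(ones)[0].tolist())
print("expect:          \t", (2*x).tolist())
print()

print("Jacobian-Vector Product")
# jvp returns (grad(foo)(W), H @ v); the product has shape (1, 2), so [0] selects its row
_, jvp_out = jvp(grad(foo), (W,), (ones,))
print("ddfoo/dW @ ones:\t", jvp_out[0].tolist())
print("expect:          \t", (2*x).tolist())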
print()


dfoo/dW:	 [[3.80531907081604, -2.361640691757202]]
expect: 	 [[3.80531907081604, -2.361640691757202]]

Vector-Jacobian Product
ones @ ddfoo/dW:	 [[3.0078017711639404, -2.051379919052124]]
expect:          	 [3.0078017711639404, -2.051379919052124]

Jacobian-Vector Product
ddfoo/dW @ ones:	 [3.0078017711639404, -2.051379919052124]
expect:          	 [3.0078017711639404, -2.051379919052124]
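The cell above contracts the Hessian with a single all-ones vector. If you contract it with each standard basis vector instead, you recover the whole matrix. The sketch below batches those products with `vmap`. It assumes a 1-D `w`, and the helper `hvp` is hypothetical.

```python
import torch
from torch.func import grad, jvp, vjp, vmap

torch.manual_seed(0)
w = torch.randn(2)
x = torch.randn(2)

def foo(w):
    return ((w ** 2) @ x).sum()

basis = torch.eye(2)  # standard basis vectors e_0, e_1 as rows

# VJP: e_i^T H is row i of the Hessian
_, vjp_fn = vjp(grad(foo), w)
H_from_vjp = vmap(vjp_fn)(basis)[0]

# JVP: H e_i is column i of the Hessian (stacked as rows, hence the .T)
def hvp(v):
    return jvp(grad(foo), (w,), (v,))[1]

H_from_jvp = vmap(hvp)(basis).T

# Summing the rows equals ones^T H, the quantity printed in the cell above
ones_vjp = H_from_vjp.sum(dim=0)
```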



# Backprop over SGD step (full gradient and truncated)

We need the gradient through an SGD step to implement MAML.
- We take the previous example and perform one gradient step.
- Then we compute the gradient of the loss after the step with respect to the original weights (a gradient of a gradient).
- There are two options:
    - full gradient: the exact gradient, which also differentiates through the inner gradient;
    - truncated gradient: an approximation in which $\nabla_{w} y$ is treated as a constant when differentiating the second time.

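SGD step (all products are elementwise): $w^*=w-\alpha \nabla_w y=w-2 \alpha w x=(1-2 \alpha x) w$

Model after the SGD step, evaluated on data $x^*$ (the code below reuses the same $x$, i.e. $x^*=x$): $y^*=\left(w-\alpha \nabla_w y\right)^2 x^*=(1-2 \alpha x)^2 w^2 x^*$

Full gradient:
$\nabla_w y^*=2(1-2 \alpha x)^2 w x^*$

Truncated gradient (a.k.a. first-order approximation), which treats $\nabla_w y$ as a constant. Then $\partial w^* / \partial w = I$, and the result is just the gradient of $y$ evaluated at $w^*$:
$\mathrm{FO}\,\nabla_w y^*=2 w^* x^*=2(1-2 \alpha x) w x^*$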
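In MAML, the inner step and the evaluation usually use different data, which is what $x$ and $x^*$ stand for above. The sketch below checks both formulas with two distinct tensors. It assumes that `x` is the inner-step data and that a hypothetical `x_star` is the post-step data. The names `loss` and `loss_after_step` are illustrative.

```python
import torch
from torch.func import grad

torch.manual_seed(0)
alpha = 0.1
w = torch.randn(2)
x = torch.randn(2)       # data for the inner SGD step
x_star = torch.randn(2)  # hypothetical post-step data (x^* above)

def loss(w, data):
    return ((w ** 2) @ data).sum()

def loss_after_step(w):
    w_star = w - alpha * grad(loss)(w, x)  # one SGD step on x
    return loss(w_star, x_star)           # evaluated on x^*

# Full gradient: grad differentiates through the inner grad call
full = grad(loss_after_step)(w)

# Truncated gradient: plain gradient of the loss at w^*
w_star = w - alpha * grad(loss)(w, x)
first_order = grad(loss)(w_star, x_star)

print(torch.allclose(full, (1 - 2 * alpha * x) ** 2 * 2 * w * x_star))  # True
print(torch.allclose(first_order, (1 - 2 * alpha * x) * 2 * w * x_star))  # True
```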

In [11]:
from torch.func import grad

W = torch.randn(1, 2, requires_grad=True)
x = torch.randn(2, requires_grad=True)
ones = torch.tensor([1.0, 1.0]).reshape(1, 2)
alpha = 0.1
print("w: ", W.tolist())
print("x: ", x.tolist())
print()

foo = lambda W: ((W ** 2)@x).sum()
gw = grad(foo)(W)
print("dfoo/dW:\t", gw.tolist())
print("expect: \t", (2*W*x).tolist())
print()

def foobar(W):
    # loss after one SGD step, as a function of the original weights W
    gw = grad(foo)(W)
    wbar = W - alpha*gw
    return foo(wbar)

# full gradient: also differentiates through the inner grad(foo) call
gw = grad(foobar)(W)
print("ddfoo/dW:\t", gw.tolist())
print("expect: \t", (((1 -2*alpha*x) ** 2) * 2*W*x).tolist())
print()

# truncated (first-order) gradient: gradient of foo evaluated at the updated weights
wbar = W - alpha*grad(foo)(W)
gw = grad(foo)(wbar)
print("FO-ddfoo/dW:\t", gw.tolist())
print("expect: \t", ((1-alpha*2*x) * 2*W*x).tolist())
print()


w:  [[1.0289456844329834, 0.9332138299942017]]
x:  [-0.04671746864914894, 0.5176782011985779]

dfoo/dW:	 [[-0.09613947570323944, 0.9662089347839355]]
expect: 	 [[-0.09613947570323944, 0.9662089347839355]]

ddfoo/dW:	 [[-0.09794442355632782, 0.776492178440094]]
expect: 	 [[-0.09794443100690842, 0.776492178440094]]

FO-ddfoo/dW:	 [[-0.09703775495290756, 0.8661718368530273]]
expect: 	 [[-0.09703775495290756, 0.8661718368530273]]
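You can also compute the same two quantities with the stateful API by using `torch.autograd.grad`, as sketched below. With `create_graph=True`, the inner gradient stays differentiable, which gives the full gradient. With `create_graph=False`, it is returned without a graph and behaves as a constant, which gives the truncated gradient. The helper `grad_through_sgd_step` is hypothetical.

```python
import torch

torch.manual_seed(0)
alpha = 0.1
w = torch.randn(2, requires_grad=True)
x = torch.randn(2)

def foo(w):
    return ((w ** 2) @ x).sum()

def grad_through_sgd_step(w, create_graph):
    # Inner gradient; with create_graph=False it carries no graph, i.e. it is a constant
    (g,) = torch.autograd.grad(foo(w), w, create_graph=create_graph)
    w_bar = w - alpha * g
    # Outer gradient of the post-step loss with respect to the original w
    (outer,) = torch.autograd.grad(foo(w_bar), w)
    return outer

full = grad_through_sgd_step(w, create_graph=True)
first_order = grad_through_sgd_step(w, create_graph=False)
```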



# vmap
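`vmap` is a vectorization/batching transform. It takes a function written for a single sample and returns a function that maps it over a batch dimension (dimension 0 by default).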

In [5]:
import torch
from torch.func import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    # note how this model computes the output for a single
    # sample, not the whole minibatch
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)  # auto-vectorization with vmap

result

tensor([0.0000, 3.1131, 0.3854], grad_fn=<ReluBackward0>)
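To make the semantics of `vmap` concrete, the sketch below compares it with an explicit Python loop over the batch. It also shows `in_dims`, where `None` marks an argument that is not mapped. It reuses the toy model with fresh random tensors.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)

def model(feature_vec):
    assert feature_vec.dim() == 1  # written for a single sample
    return feature_vec.dot(weights).relu()

batched = vmap(model)(examples)                     # shape (batch_size,)
looped = torch.stack([model(e) for e in examples])  # same values, one call per row

# Map over dim 0 of examples; weights are shared across the batch (None)
dots = vmap(torch.dot, in_dims=(0, None))(examples, weights)
```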

# vmap + grad: per-sample gradients (e.g. for the Fisher)
Combining `vmap` and `grad` lets us compute per-sample gradients in one vectorized call. Without them, we would need to compute the gradient for each sample independently, e.g. one gradient call per sample in a Python loop.
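The sketch below shows both approaches side by side. It uses the same toy model with the loss written for a single sample and drops the `.mean()`, since the mean of a scalar is a no-op. Tensors are fresh random draws.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
batch_size, feature_size = 3, 5

def compute_loss(weights, example, target):
    y = example.dot(weights).relu()  # toy linear model with ReLU
    return (y - target) ** 2         # squared error for one sample

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Naive approach: one gradient computation per sample
looped = torch.stack([grad(compute_loss)(weights, e, t) for e, t in zip(examples, targets)])

# Vectorized: map over examples and targets (dim 0), share weights (None)
vmapped = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)
```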

In [None]:
from torch.func import vmap, grad
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
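inputs = (weights, examples, targets)
# map over examples and targets (dim 0); weights are shared across samples (None)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

grad_weight_per_example  # one gradient row per sample: shape (batch_size, feature_size)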

tensor([[  2.6212,   1.9154,  -0.2961,  -0.7618,  -6.7760],
        [  0.0000,   0.0000,  -0.0000,   0.0000,   0.0000],
        [  7.0920,  -4.7760,   2.1690,  -6.6394, -14.7211]],
       grad_fn=<MulBackward0>)
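Per-sample gradients make it possible to build the empirical Fisher matrix, $F \approx \frac{1}{B}\sum_i g_i g_i^\top$, where $g_i$ is the gradient for sample $i$. The sketch below computes it. Note that this is the empirical approximation built from the observed targets, not the exact Fisher. Tensors are fresh random draws.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

def compute_loss(weights, example, target):
    y = example.dot(weights).relu()
    return (y - target) ** 2

# Per-sample gradients g_i as rows, shape (batch_size, feature_size)
g = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Empirical Fisher: average outer product, shape (feature_size, feature_size)
fisher = g.T @ g / batch_size
# Same quantity written as an explicit sum of outer products g_i g_i^T
fisher_check = torch.einsum("bi,bj->ij", g, g) / batch_size
```

With `batch_size < feature_size`, this matrix has rank at most `batch_size`, so in practice it is usually damped before being inverted.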