Add make_jvp and make_ggnvp convenience wrappers. (#237)
Also change the names inside make_vjp to be more consistent with vjp usage
elsewhere.
mattjj committed Jun 8, 2017
1 parent 3a84498 · commit a055b42
Showing 4 changed files with 114 additions and 39 deletions.
autograd/__init__.py (8 changes: 5 additions & 3 deletions)
@@ -3,6 +3,8 @@
from . import container_types
from .container_types import make_tuple, make_list, make_dict
from .convenience_wrappers import (grad, multigrad, multigrad_dict, elementwise_grad,
value_and_grad, grad_and_aux, hessian_vector_product,
hessian, jacobian, vector_jacobian_product, grad_named,
checkpoint, make_hvp, value_and_multigrad)
value_and_grad, grad_and_aux, hessian_tensor_product,
hessian_vector_product, hessian, jacobian,
tensor_jacobian_product, vector_jacobian_product,
grad_named, checkpoint, make_hvp, value_and_multigrad,
make_jvp, make_ggnvp)
autograd/convenience_wrappers.py (59 changes: 38 additions & 21 deletions)
@@ -75,45 +75,62 @@ def multigrad_fun(*args, **kwargs):
return double_val_fun(*args, **kwargs)[1]
return multigrad_fun

def elementwise_grad(fun, argnum=0):
"""Like `jacobian`, but produces a function which computes just the diagonal
of the Jacobian, and does the computation in one pass rather than in a loop.
Note: this is only valid if the Jacobian is diagonal. Only arrays are
currently supported. Can be used for broadcasting."""
def sum_output(*args, **kwargs):
return np.sum(fun(*args, **kwargs))
return grad(sum_output, argnum=argnum)
elementwise_grad = grad # backward compatibility

def hessian(fun, argnum=0):
"Returns a function that computes the exact Hessian."
return jacobian(jacobian(fun, argnum), argnum)

def make_hvp(fun, argnum=0):
"""Constructs a function for evaluating the Hessian-vector product at a
point, which may be useful when evaluating many Hessian-vector products at
the same point while caching the results of the forward pass."""
"""Builds a function for evaluating the Hessian-vector product at a point,
which may be useful when evaluating many Hessian-vector products at the same
point while caching the results of the forward pass."""
def hvp_maker(*args, **kwargs):
return make_vjp(grad(fun, argnum), argnum)(*args, **kwargs)[0]
return hvp_maker
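
For context, a minimal usage sketch of make_hvp as defined above (illustrative only, not from this commit; the example function and shapes are made up):

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import hessian, make_hvp

f = lambda x: np.sum(np.sin(x))   # scalar-valued example function
x, v = npr.randn(5), npr.randn(5)

hvp = make_hvp(f)(x)              # the forward pass at x happens once here
print(np.allclose(hvp(v), np.dot(hessian(f)(x), v)))  # True; hvp(v) reuses that pass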

def hessian_vector_product(fun, argnum=0):
"""Builds a function that returns the exact Hessian-vector product.
The returned function has arguments (*args, vector, **kwargs), and takes
roughly 4x as long to evaluate as the original function."""
def hessian_tensor_product(fun, argnum=0):
"""Builds a function that returns the exact Hessian-tensor product.
The returned function has arguments (*args, tensor, **kwargs), and for
vectors takes roughly 4x as long to evaluate as the original function."""
fun_grad = grad(fun, argnum)
def vector_dot_grad(*args, **kwargs):
args, vector = args[:-1], args[-1]
return np.tensordot(fun_grad(*args, **kwargs), vector, np.ndim(vector))
return grad(vector_dot_grad, argnum) # Grad wrt original input.
return grad(vector_dot_grad, argnum)
hessian_vector_product = hessian_tensor_product
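
Below, a hedged usage sketch of the renamed hessian_tensor_product, mirroring test_hessian_tensor_product in the test diff further down (the function and shapes are illustrative):

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import hessian, hessian_tensor_product

f = lambda a: np.sum(np.sin(a))
a, v = npr.randn(5), npr.randn(5)

# For a vector argument this is the familiar Hessian-vector product H @ v:
print(np.allclose(hessian_tensor_product(f)(a, v), np.dot(hessian(f)(a), v)))  # True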

def vector_jacobian_product(fun, argnum=0):
"""Builds a function that returns the exact vector-Jacobian product, that
is the Jacobian matrix left-multiplied by vector. The returned function
has arguments (*args, vector, **kwargs)."""
def tensor_jacobian_product(fun, argnum=0):
"""Builds a function that returns the exact tensor-Jacobian product, that
is the Jacobian matrix left-multiplied by tensor. The returned function
has arguments (*args, tensor, **kwargs)."""
def vector_dot_fun(*args, **kwargs):
args, vector = args[:-1], args[-1]
return np.tensordot(vector, fun(*args, **kwargs), axes=np.ndim(vector))
return jacobian(vector_dot_fun, argnum) # Grad wrt original input.
return jacobian(vector_dot_fun, argnum)
vector_jacobian_product = tensor_jacobian_product
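
And a corresponding sketch for tensor_jacobian_product, following test_tensor_jacobian_product below (illustrative function and shapes):

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import jacobian, tensor_jacobian_product

f = lambda a: np.roll(np.sin(a), 1)   # asymmetric Jacobian, as in the tests
a, v = npr.randn(5), npr.randn(5)

# The Jacobian left-multiplied by v, i.e. v^T J:
print(np.allclose(tensor_jacobian_product(f)(a, v), np.dot(v, jacobian(f)(a))))  # True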

def make_jvp(fun, argnum=0):
"""Builds a function for evaluating the Jacobian-vector product at a
point. Roughly 1.5x more FLOPs than forward-mode, plus memory requirements
that scale with the number of primitives applied in the evaluation of f, as
well as other overheads. See github.com/BB-UCL/autograd-forward."""
def jvp_maker(*args, **kwargs):
vjp, y = make_vjp(fun, argnum)(*args, **kwargs)
vjp_vjp, _ = make_vjp(vjp)(vspace(getval(y)).zeros())
return vjp_vjp # vjp_vjp is just jvp by linearity
return jvp_maker
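
A small usage sketch of make_jvp, matching test_make_jvp in the test diff below (the matrix A and the shapes are illustrative):

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import jacobian, make_jvp

A = npr.randn(3, 5)
f = lambda x: np.tanh(np.dot(A, x))
x, v = npr.randn(5), npr.randn(5)

# Jacobian-vector product computed by the double-VJP trick above:
print(np.allclose(make_jvp(f)(x)(v), np.dot(jacobian(f)(x), v)))  # True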

def make_ggnvp(f, g=lambda x: 1./2*np.sum(x**2, axis=-1), f_argnum=0):
"""Builds a function for evaluating generalized-Gauss-Newton-vector products
at a point. Slightly more expensive than mixed-mode."""
def ggnvp_maker(*args, **kwargs):
f_vjp, f_x = make_vjp(f, f_argnum)(*args, **kwargs)
g_hvp, grad_g_x = make_vjp(grad(g))(f_x)
f_vjp_vjp, _ = make_vjp(f_vjp)(vspace(getval(grad_g_x)).zeros())
def ggnvp(v): return f_vjp(g_hvp(f_vjp_vjp(v)))
return ggnvp
return ggnvp_maker
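
For a function f and curvature function g, the generalized Gauss-Newton matrix at x is J_f(x)^T H_g(f(x)) J_f(x), and make_ggnvp(f, g)(x)(v) applies it to v (compare _make_explicit_ggnvp in the test diff below). A hedged usage sketch with illustrative shapes:

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import make_ggnvp

A = npr.randn(5, 4)
f = lambda x: np.tanh(np.dot(A, x))
x, v = npr.randn(4), npr.randn(4)

# With the default g(y) = 0.5 * np.sum(y**2, axis=-1), H_g is the identity,
# so this is J^T J v for J = jacobian(f)(x):
ggnvp = make_ggnvp(f)(x)
print(ggnvp(v))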

def value_and_grad(fun, argnum=0):
"""Returns a function that returns both value and gradient. Suitable for use
autograd/core.py (10 changes: 6 additions & 4 deletions)
@@ -10,13 +10,15 @@
from .errors import defgrad_deprecated

def make_vjp(fun, argnum=0):
def vjp(*args, **kwargs):
def vjp_maker(*args, **kwargs):
start_node, end_node = forward_pass(fun, args, kwargs, argnum)
if not isnode(end_node) or start_node not in end_node.progenitors:
warnings.warn("Output seems independent of input.")
return lambda g : start_node.vspace.zeros(), end_node
return lambda g : backward_pass(g, end_node, start_node), end_node
return vjp
def vjp(g): return start_node.vspace.zeros()
else:
def vjp(g): return backward_pass(g, end_node, start_node)
return vjp, end_node
return vjp_maker
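
For reference, a minimal sketch of how the renamed vjp_maker is consumed, consistent with the wrappers above (the function and shapes are illustrative):

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.core import make_vjp

A = npr.randn(3, 5)
f = lambda x: np.tanh(np.dot(A, x))
x, u = npr.randn(5), npr.randn(3)

vjp, y = make_vjp(f)(x)   # one forward pass; y holds the output f(x)
print(vjp(u))             # u^T times the Jacobian of f at x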

def forward_pass(fun, args, kwargs, argnum=0):
args = list(args)
tests/test_wrappers.py (76 changes: 65 additions & 11 deletions)
@@ -5,9 +5,9 @@
import autograd.numpy.random as npr
from autograd.util import *
from autograd import (grad, elementwise_grad, jacobian, value_and_grad,
grad_and_aux, hessian_vector_product, hessian, make_hvp,
multigrad, jacobian, vector_jacobian_product, primitive,
checkpoint, value_and_multigrad)
grad_and_aux, hessian_tensor_product, hessian, make_hvp,
multigrad, jacobian, tensor_jacobian_product, primitive,
checkpoint, value_and_multigrad, make_jvp, make_ggnvp)
from builtins import range

npr.seed(1)
@@ -112,12 +112,12 @@ def simple_fun(a, b):
numeric = np.squeeze(np.array([nd(simple_fun, A, B[i])[argnum] for i in range(len(B))]))
check_equivalent(exact, numeric)

def test_hessian_vector_product():
def test_hessian_tensor_product():
fun = lambda a: np.sum(np.sin(a))
a = npr.randn(5)
v = npr.randn(5)
H = hessian(fun)(a)
check_equivalent(np.dot(H, v), hessian_vector_product(fun)(a, v))
check_equivalent(np.dot(H, v), hessian_tensor_product(fun)(a, v))

def test_hvp():
fun = lambda a: np.sum(np.sin(a))
@@ -132,36 +132,36 @@ def test_hessian_matrix_product():
a = npr.randn(5, 4)
V = npr.randn(5, 4)
H = hessian(fun)(a)
check_equivalent(np.tensordot(H, V), hessian_vector_product(fun)(a, V))
check_equivalent(np.tensordot(H, V), hessian_tensor_product(fun)(a, V))

def test_hessian_tensor_product():
fun = lambda a: np.sum(np.sin(a))
a = npr.randn(5, 4, 3)
V = npr.randn(5, 4, 3)
H = hessian(fun)(a)
check_equivalent(np.tensordot(H, V, axes=np.ndim(V)), hessian_vector_product(fun)(a, V))
check_equivalent(np.tensordot(H, V, axes=np.ndim(V)), hessian_tensor_product(fun)(a, V))

def test_vector_jacobian_product():
def test_tensor_jacobian_product():
# This function will have an asymmetric jacobian matrix.
fun = lambda a: np.roll(np.sin(a), 1)
a = npr.randn(5)
V = npr.randn(5)
J = jacobian(fun)(a)
check_equivalent(np.dot(V.T, J), vector_jacobian_product(fun)(a, V))
check_equivalent(np.dot(V.T, J), tensor_jacobian_product(fun)(a, V))

def test_matrix_jacobian_product():
fun = lambda a: np.roll(np.sin(a), 1)
a = npr.randn(5, 4)
V = npr.randn(5, 4)
J = jacobian(fun)(a)
check_equivalent(np.tensordot(V, J), vector_jacobian_product(fun)(a, V))
check_equivalent(np.tensordot(V, J), tensor_jacobian_product(fun)(a, V))

def test_tensor_jacobian_product():
fun = lambda a: np.roll(np.sin(a), 1)
a = npr.randn(5, 4, 3)
V = npr.randn(5, 4)
J = jacobian(fun)(a)
check_equivalent(np.tensordot(V, J, axes=np.ndim(V)), vector_jacobian_product(fun)(a, V))
check_equivalent(np.tensordot(V, J, axes=np.ndim(V)), tensor_jacobian_product(fun)(a, V))

def test_deprecated_defgrad_wrapper():
@primitive
@@ -234,3 +234,57 @@ def testfun(f, x):
max_checkpointed_usage = max(memory_usage((gradfun, (checkpointed_f, A))))

assert max_checkpointed_usage < max_usage / 2.

def test_make_jvp():
A = npr.randn(3, 5)
x = npr.randn(5)
v = npr.randn(5)
fun = lambda x: np.tanh(np.dot(A, x))

jvp_explicit = lambda x: lambda v: np.dot(jacobian(fun)(x), v)
jvp = make_jvp(fun)

check_equivalent(jvp_explicit(x)(v), jvp(x)(v))

def _make_explicit_ggnvp(f, g=lambda x: 1./2*np.dot(x, x)):
def ggnvp_maker(x):
J = jacobian(f)(x)
H = hessian(g)(f(x))
def ggnvp(v):
return np.dot(J.T, np.dot(H, np.dot(J, v)))
return ggnvp
return ggnvp_maker

def test_make_ggnvp():
A = npr.randn(5, 4)
x = npr.randn(4)
v = npr.randn(4)

fun = lambda x: np.dot(A, x)
check_equivalent(make_ggnvp(fun)(x)(v), _make_explicit_ggnvp(fun)(x)(v))

fun2 = lambda x: np.tanh(np.dot(A, x))
check_equivalent(make_ggnvp(fun2)(x)(v), _make_explicit_ggnvp(fun2)(x)(v))

def test_make_ggnvp_nondefault_g():
A = npr.randn(5, 4)
x = npr.randn(4)
v = npr.randn(4)

g = lambda y: np.sum(2.*y**2 + y**4)

fun = lambda x: np.dot(A, x)
check_equivalent(make_ggnvp(fun, g)(x)(v), _make_explicit_ggnvp(fun, g)(x)(v))

fun2 = lambda x: np.tanh(np.dot(A, x))
check_equivalent(make_ggnvp(fun2, g)(x)(v), _make_explicit_ggnvp(fun2, g)(x)(v))

def test_make_ggnvp_broadcasting():
A = npr.randn(4, 5)
x = npr.randn(10, 4)
v = npr.randn(10, 4)

fun = lambda x: np.tanh(np.dot(x, A))
res1 = np.stack([_make_explicit_ggnvp(fun)(xi)(vi) for xi, vi in zip(x, v)])
res2 = make_ggnvp(fun)(x)(v)
check_equivalent(res1, res2)
