From e511c1878ecdff0bce1ca9057e610654f9f2a563 Mon Sep 17 00:00:00 2001
From: Jamie Townsend
Date: Thu, 22 Feb 2018 18:46:46 -0500
Subject: [PATCH] Express matmul vjps in terms of matmul

---
 autograd/numpy/numpy_vjps.py | 55 +++++++++++++++++++++++++++---------
 1 file changed, 42 insertions(+), 13 deletions(-)

diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py
index 8b79bd44..4599dea2 100644
--- a/autograd/numpy/numpy_vjps.py
+++ b/autograd/numpy/numpy_vjps.py
@@ -317,19 +317,48 @@ def grad_inner(argnum, ans, A, B):
     return lambda G: tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim)
 defvjp(anp.inner, partial(grad_inner, 0), partial(grad_inner, 1))
 
-def grad_matmul(argnum, ans, A, B):
-    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
-    if A_ndim == 0 or B_ndim == 0:
-        raise ValueError("Scalar operands are not allowed, use '*' instead")
-    elif A_ndim == 1 or B_ndim == 1 or (A_ndim == 2 and B_ndim == 2):
-        axes = ([A_ndim - 1], [max(0, B_ndim - 2)])
-        if argnum == 0:
-            return lambda G: match_complex(A, tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim))
-        elif argnum == 1:
-            return lambda G: match_complex(B, tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim))
-    else:
-        return grad_einsum(argnum + 1, ans, ("...ij,...jk->...ik", A, B), None)
-defvjp(anp.matmul, partial(grad_matmul, 0), partial(grad_matmul, 1))
+def matmul_adjoint_0(B, G, A_meta, B_ndim):
+    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
+        return unbroadcast(G * B, A_meta)
+    _, A_ndim, _, _ = A_meta
+    if A_ndim == 1:
+        G = anp.expand_dims(G, anp.ndim(G) - 1)
+    if B_ndim == 1:  # The result we need is an outer product
+        B = anp.expand_dims(B, 0)
+        G = anp.expand_dims(G, anp.ndim(G))
+    else:  # We need to swap the last two axes of B
+        B = anp.swapaxes(B, B_ndim - 2, B_ndim - 1)
+    result = anp.matmul(G, B)
+    return unbroadcast(result, A_meta)
+
+def matmul_adjoint_1(A, G, A_ndim, B_meta):
+    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
+        return unbroadcast(G * A, B_meta)
+    _, B_ndim, _, _ = B_meta
+    B_is_vec = (B_ndim == 1)
+    if B_is_vec:
+        G = anp.expand_dims(G, anp.ndim(G))
+    if A_ndim == 1:  # The result we need is an outer product
+        A = anp.expand_dims(A, 1)
+        G = anp.expand_dims(G, anp.ndim(G) - 1)
+    else:  # We need to swap the last two axes of A
+        A = anp.swapaxes(A, A_ndim - 2, A_ndim - 1)
+    result = anp.matmul(A, G)
+    if B_is_vec:
+        result = anp.squeeze(result, anp.ndim(G) - 1)
+    return unbroadcast(result, B_meta)
+
+def matmul_vjp_0(ans, A, B):
+    A_meta = anp.metadata(A)
+    B_ndim = anp.ndim(B)
+    return lambda g: matmul_adjoint_0(B, g, A_meta, B_ndim)
+
+def matmul_vjp_1(ans, A, B):
+    A_ndim = anp.ndim(A)
+    B_meta = anp.metadata(B)
+    return lambda g: matmul_adjoint_1(A, g, A_ndim, B_meta)
+
+defvjp(anp.matmul, matmul_vjp_0, matmul_vjp_1)
 
 @primitive
 def dot_adjoint_0(B, G, A_ndim, B_ndim):
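
Note (not part of the patch): a minimal standalone sketch in plain NumPy that sanity-checks what the new adjoints compute in the two common cases. In the matrix-matrix case, matmul_adjoint_0 reduces to G @ swapaxes(B, -1, -2) and matmul_adjoint_1 to swapaxes(A, -1, -2) @ G; the fd_grad helper and all variable names here are illustrative only, not autograd API.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
G = rng.standard_normal((3, 5))  # incoming cotangent, same shape as A @ B

# Matrix-matrix case: swap the last two axes of the other operand, matmul with G.
dA = G @ np.swapaxes(B, -1, -2)   # what matmul_adjoint_0 computes for 2-D inputs
dB = np.swapaxes(A, -1, -2) @ G   # what matmul_adjoint_1 computes for 2-D inputs

# Central-difference gradient of a scalar function f at x, entry by entry.
def fd_grad(f, x, eps=1e-6):
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        dx = np.zeros_like(x)
        dx[idx] = eps
        g[idx] = (f(x + dx) - f(x - dx)) / (2 * eps)
    return g

# d/dA sum(G * (A @ B)) == G @ B^T and d/dB == A^T @ G.
assert np.allclose(dA, fd_grad(lambda A_: np.sum(G * (A_ @ B)), A), atol=1e-5)
assert np.allclose(dB, fd_grad(lambda B_: np.sum(G * (A @ B_)), B), atol=1e-5)

# Vector-times-matrix case (A_ndim == 1): the same contraction holds once the
# expand_dims bookkeeping in matmul_adjoint_0 is accounted for.
a = rng.standard_normal(4)
g = rng.standard_normal(5)        # cotangent of a @ B, shape (5,)
da = g @ np.swapaxes(B, -1, -2)   # shape (4,), matches a
assert np.allclose(da, fd_grad(lambda a_: np.sum(g * (a_ @ B)), a), atol=1e-5)

The expand_dims/squeeze pairs in the patch handle exactly these rank-deficient cases: a 1-D operand is promoted to a row or column matrix so the core rule "matmul G with the swapped other operand" applies uniformly, and unbroadcast then restores the original (possibly broadcast) shape recorded in the metadata.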