diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index 06d2e09b..8b79bd44 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -523,8 +523,7 @@ def vjp(g): new_operands = (g,) + rest_of_ops new_subscripts = new_input_subs + '->' + subs_wrt - # TODO(mattjj): remove optimize=False after github.com/numpy/numpy/issues/10343 - return unbroadcast(anp.einsum(new_subscripts, *new_operands, optimize=False), result_meta) + return unbroadcast(anp.einsum(new_subscripts, *new_operands), result_meta) else: # using (op0, sublist0, op1, sublist1, ..., sublistout) convention if len(operands) % 2 == 0: raise NotImplementedError("Need sublistout argument")