Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

add a few gradients #5899

Merged
merged 1 commit into from Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 57 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Expand Up @@ -270,6 +270,14 @@ def abs_grad(orig, grad):
return [where(less(x, zeros), -ones * grad, ones * grad)]


@register_gradient("erf")
def erf_grad(orig, grad):
    """Gradient of erf: grad * 2 / sqrt(pi) * exp(-x * x).

    Parameters
    ----------
    orig : the original erf call; its single argument is the input x.
    grad : the gradient flowing in from the output.

    Returns
    -------
    A one-element list holding the gradient with respect to x.
    """
    (x,) = orig.args
    # 2 / sqrt(pi) written out as a literal, typed to match the input's dtype.
    two_over_sqrt_pi = const(1.1283791670955126, dtype=x.checked_type.dtype)
    gaussian = exp(-x * x)
    return [two_over_sqrt_pi * gaussian * grad]


@register_gradient("clip")
def clip_grad(orig, grad):
"""Returns grad * (select(x < min || max < x , 0, 1))."""
Expand Down Expand Up @@ -479,6 +487,19 @@ def dense_grad(orig, grad):
collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
units=data.checked_type.shape[1]), weight)]


@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
    """Gradient of nn.batch_matmul.

    In einsum notation the forward op is LHS_bik,RHS_bjk->RES_bij, so:
        GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
        GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk

    Returns
    -------
    A two-element list: [gradient w.r.t. lhs, gradient w.r.t. rhs], each passed
    through collapse_sum_like against the corresponding input.
    """
    lhs, rhs = orig.args

    def _swap_last_two(tensor):
        # Keep the batch axis in place and exchange the two trailing axes.
        return transpose(tensor, [0, 2, 1])

    wrt_lhs = _nn.batch_matmul(grad, _swap_last_two(rhs))
    wrt_rhs = _nn.batch_matmul(_swap_last_two(grad), _swap_last_two(lhs))
    return [collapse_sum_like(wrt_lhs, lhs), collapse_sum_like(wrt_rhs, rhs)]


@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
Expand Down Expand Up @@ -529,6 +550,42 @@ def sum_grad(orig, grad):
return [broadcast_to_like(grad, data)]


@register_gradient("mean")
def mean_grad(orig, grad):
    """Gradient of mean: grad scaled by 1 / (product of reduced dims), broadcast like data.

    Parameters
    ----------
    orig : the original mean call; args[0] is the data, attrs.keepdims is read.
    grad : the gradient flowing in from the output.

    Returns
    -------
    A one-element list holding the gradient with respect to data.
    """
    data = orig.args[0]
    reduce_axes = _get_reduce_axis(orig)
    dims = data.checked_type.concrete_shape
    if reduce_axes is None:
        # No axes reported: treat every dimension as reduced.
        reduce_axes = list(range(len(dims)))
    if not orig.attrs.keepdims:
        # presumably restores the reduced axes so grad lines up with data -- see _unreduce_expand
        grad = _unreduce_expand(grad, reduce_axes)
    # Divide one axis at a time to build 1 / (number of averaged elements).
    scale = 1.0
    for ax in reduce_axes:
        scale /= dims[ax]
    scaled_grad = grad * const(scale, dtype=data.checked_type.dtype)
    return [broadcast_to_like(scaled_grad, data)]


@register_gradient("variance")
def variance_grad(orig, grad):
    """Gradient of variance, whose call node takes the mean as its second argument.

    Parameters
    ----------
    orig : the original variance call; args[0] is the data, args[1] the mean,
        and attrs.keepdims is read.
    grad : the gradient flowing in from the output.

    Returns
    -------
    A two-element list:
        [grad * (2 / product of reduced dims) * data, -2 * grad * mean]
    i.e. one entry per call argument (data, mean).
    """
    data, data_mean = orig.args[0], orig.args[1]
    reduce_axes = _get_reduce_axis(orig)
    dims = data.checked_type.concrete_shape
    if reduce_axes is None:
        # No axes reported: treat every dimension as reduced.
        reduce_axes = list(range(len(dims)))
    if not orig.attrs.keepdims:
        # presumably restores the reduced axes so grad lines up with data -- see _unreduce_expand
        grad = _unreduce_expand(grad, reduce_axes)
    # Divide one axis at a time to build 2 / (number of reduced elements).
    scale = 2.0
    for ax in reduce_axes:
        scale /= dims[ax]
    dtype = data.checked_type.dtype
    wrt_data = (grad * const(scale, dtype=dtype)) * data
    wrt_mean = const(-2, dtype=dtype) * grad * data_mean
    return [wrt_data, wrt_mean]


@register_gradient("copy")
def copy_grad(orig, grad):
    """Gradient of copy: the incoming gradient is returned unchanged."""
    passthrough = grad
    return [passthrough]


@register_gradient("nn.cross_entropy")
def cross_entropy_grad(orig, grad):
x, y = orig.args
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_op_grad_level1.py
Expand Up @@ -62,6 +62,7 @@ def check_single_op(opfunc, ref):
(tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
(tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
(tvm.relay.erf, lambda x: 2.0 / (np.pi**(0.5)) * np.exp(-x * x)),
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
Expand Down
6 changes: 6 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
Expand Up @@ -44,5 +44,11 @@ def test_checkpoint():
check_grad(relay.Function(inputs, out_single))


def test_batch_matmul_grad():
    """Check the nn.batch_matmul gradient on float64 inputs of shape (2, 3, 5) and (2, 4, 5)."""
    lhs = relay.var("x", shape=(2, 3, 5), dtype="float64")
    rhs = relay.var("y", shape=(2, 4, 5), dtype="float64")
    fwd_func = relay.Function([lhs, rhs], relay.op.nn.batch_matmul(lhs, rhs))
    check_grad(fwd_func)


if __name__ == "__main__":
    # Executed directly as a script: hand this file to pytest.
    pytest.main([__file__])
7 changes: 7 additions & 0 deletions tests/python/relay/test_op_grad_level3.py
Expand Up @@ -64,5 +64,12 @@ def test_cast_grad():
fwd_func = relay.Function([data], relay.cast(data, "float64"))
check_grad(fwd_func)


def test_copy_grad():
    """Check the copy gradient on a (10, 4) float64 input."""
    input_type = relay.TensorType((10, 4), "float64")
    data = relay.var("data", input_type)
    check_grad(relay.Function([data], relay.copy(data)))


if __name__ == "__main__":
    # Executed directly as a script: start pytest with no explicit arguments.
    pytest.main()
15 changes: 8 additions & 7 deletions tests/python/relay/test_op_grad_level4.py
Expand Up @@ -19,17 +19,18 @@
from tvm.relay.testing import check_grad


# NOTE(review): diff residue -- the +/- markers were lost in extraction. Going by the
# hunk summary for this file (8 additions & 7 deletions), the first signature below
# appears to be the removed one and the second its replacement, which takes the
# reduction op to test as `red_fn` -- confirm against the real diff.
def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
# Likewise: the relay.sum line looks like the removed version, the red_fn line its replacement.
fwd_func = relay.Function([data], relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude))
fwd_func = relay.Function([data], red_fn(data, axis=axis, keepdims=keepdims, exclude=exclude))
# Checks the gradient of the forward function built above.
check_grad(fwd_func)


# NOTE(review): diff residue -- the +/- markers were lost in extraction. Going by the
# hunk summary for this file (8 additions & 7 deletions), test_sum_grad (five lines)
# appears to be the removed test and test_reduction_grad (six lines) its
# replacement -- confirm against the real diff.
def test_sum_grad():
verify_sum_grad((4, 2))
verify_sum_grad((4, 2), axis=-1, keepdims=True)
verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
verify_sum_grad((4, 2, 1), axis=1)
# Runs the same four shape/axis configurations for each of sum, variance and mean.
def test_reduction_grad():
for op in (relay.sum, relay.variance, relay.mean):
verify_reduction_grad(op, (4, 2))
verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
verify_reduction_grad(op, (4, 2, 1), axis=1)


def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
Expand Down