Skip to content

Commit

Permalink
[Relay][Training] Add gradient for max. (#3915)
Browse files Browse the repository at this point in the history
* save

* save
  • Loading branch information
MarisaKirisame authored and tmoreau89 committed Sep 9, 2019
1 parent 83d2418 commit 0f4c151
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
14 changes: 13 additions & 1 deletion python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
from .transform import (
broadcast_to_like,
collapse_sum_like,
Expand Down Expand Up @@ -269,6 +269,18 @@ def conv2d_grad(orig, grad):
return [backward_data, backward_weight]


@register_gradient("max")
def max_grad(orig, grad):
"""Returns the gradient of max"""
# Only support axis=0, since broadcasting orig to x behaves incorrectly
x, axis = orig.args[0], orig.attrs.axis
assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
orig = broadcast_to_like(orig, x)
grad = broadcast_to_like(grad, x)
indicators = cast_like(equal(orig, x), grad)
return [indicators * grad]


@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
"""Gradient of softmax"""
Expand Down
13 changes: 12 additions & 1 deletion tests/python/relay/test_op_grad_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from tvm import relay
from tvm.relay.testing import check_grad

Expand All @@ -30,6 +31,16 @@ def test_sum_grad():
verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)


def test_max_grad():
s = (5, 10)
t = relay.TensorType(s)
x = relay.var("x", t)
axis = 0
z = relay.max(x, axis)

fwd_func = relay.Function([x], z)
check_grad(fwd_func, eps=1e-7, rtol=1)


if __name__ == "__main__":
test_sum_grad()
pytest.main()

0 comments on commit 0f4c151

Please sign in to comment.