[RELAY][GRAD] handle Tuple/TupleGetItem in first order gradient (#5946)
* handle Tuple/TupleGetItem in first order gradient

* Unify MultiOnes/MultiZeros.
t-vi committed Jun 30, 2020
1 parent 5d445ca commit 957aefb
Showing 3 changed files with 104 additions and 9 deletions.
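In short: the first-order gradient pass previously created every adjoint with ZerosLike/OnesLike, which assumes tensor-typed values, so differentiating expressions that build or index tuples failed. This commit adds Tuple/TupleGetItem visitors and tuple-aware zero/one constructors. A minimal usage sketch mirroring the new test (assumptions: a TVM build containing this commit; run_infer_type is imported from tvm.relay.testing as in the test file below):

from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((10, 10), "float32")
x, y = relay.var("x", t), relay.var("y", t)
tup = relay.Tuple([x, y])
# f(x, y) = (x, y)[0] + (x, y)[1]; both gradients should be all ones.
func = relay.Function([x, y], relay.TupleGetItem(tup, 0) + relay.TupleGetItem(tup, 1))
func = run_infer_type(func)
# Before this commit, mode="first_order" could not handle Tuple/TupleGetItem.
back_func = run_infer_type(gradient(func, mode="first_order"))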
src/relay/transforms/gradient.cc: 77 additions & 2 deletions
@@ -106,14 +106,44 @@ struct ADValueNode {
  }
};

// Builds a value of type `t` via `factory` (e.g. Zeros/Ones): tensors are
// created directly, tuples are built recursively field by field.
template <typename F>
Expr MultiFactory(const Type& t, F factory) {
  if (auto* tt = t.as<TensorTypeNode>()) {
    return factory(tt->shape, tt->dtype);
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    std::vector<Expr> res;
    for (size_t i = 0; i < tt->fields.size(); i++) {
      res.push_back(MultiFactory(tt->fields[i], factory));
    }
    return Tuple(res);
  } else {
    LOG(FATAL) << "unsupported type to create tensors of: " << t;
    throw;
  }
}

// Like MultiFactory, but for a plain tensor uses `factory_like(e)` (e.g.
// ZerosLike), which does not require a constant shape; tuple types fall
// back to building each field with `factory`.
template <typename F, typename F2>
Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) {
  if (t.as<TensorTypeNode>()) {
    return factory_like(e);
  } else if (t.as<TupleTypeNode>()) {
    return MultiFactory(t, factory);
  } else {
    LOG(FATAL) << "unsupported type to create tensors like: " << t;
    throw;
  }
}

using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse;  // must be a variable to avoid duplication
  ADTensor(LetList* ll, const Expr& forward)
-      : forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
+      : forward(ll->Push(forward)),
+        reverse(
+            ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) {
    this->forward->checked_type_ = forward->checked_type();
  }
};
@@ -165,6 +195,51 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
    });
  }

  ADValue VisitExpr_(const TupleGetItemNode* op) final {
    Expr e = GetRef<Expr>(op);
    ADValue tup = VisitExpr(op->tuple);
    auto tt = op->tuple->checked_type().as<TupleTypeNode>();
    size_t size = tt->fields.size();
    size_t idx = op->index;
    auto ret = std::make_shared<ADTensor>(ll, e);
    backprop_actions.push_back([tup, idx, size, ret](LetList* ll) {
      auto rev = tup->get<ADTensor>().reverse;
      // Special-case Tuple to avoid long chains of GetItem/Tuple, but
      // functions may also produce tuples, so we cannot assume the
      // reverse node is always a Tuple.
      std::vector<Expr> grfields;
      if (auto tup_node = rev.as<TupleNode>()) {
        for (size_t i = 0; i < size; ++i) {
          grfields.push_back(i != idx ? tup_node->fields[i]
                                      : Add(tup_node->fields[i], ret->reverse));
        }
      } else {
        for (size_t i = 0; i < size; ++i) {
          grfields.push_back(i != idx ? TupleGetItem(rev, i)
                                      : Add(TupleGetItem(rev, i), ret->reverse));
        }
      }
      tup->get<ADTensor>().reverse = ll->Push(Tuple(grfields));
    });
    return ret;
  }

  ADValue VisitExpr_(const TupleNode* op) final {
    Expr e = GetRef<Expr>(op);
    std::vector<ADValue> fields;
    for (const auto& f : op->fields) {
      fields.push_back(VisitExpr(f));
    }
    auto ret = std::make_shared<ADTensor>(ll, e);
    backprop_actions.push_back([fields, ret](LetList* ll) {
      // Distribute the tuple's adjoint: d(field_i) += d(tuple)[i].
      for (size_t i = 0; i < fields.size(); ++i) {
        fields[i]->get<ADTensor>().reverse =
            ll->Push(Add(fields[i]->get<ADTensor>().reverse, TupleGetItem(ret->reverse, i)));
      }
    });
    return ret;
  }

  ADValue VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return std::make_shared<ADTensor>(ll, e);
@@ -235,7 +310,7 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
    auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
    const auto& res = c->get<ADTensor>();
    Expr grad = LetList::With([&](LetList* ll) {
-      res.reverse = OnesLike(res.forward);
+      res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike);
      for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend();
           ++it) {
        (*it)(ll);
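Taken together, the two new visitors implement the standard adjoint rules for tuples: TupleGetItem gathers the output adjoint into a single slot of the tuple's adjoint, and Tuple scatters its adjoint across its fields. A minimal NumPy sketch of the same bookkeeping (illustrative only; these function names are hypothetical and not part of the commit):

import numpy as np

def tuple_get_item_backprop(tuple_adjoint, idx, result_adjoint):
    # TupleGetItemNode rule: only slot `idx` accumulates the result's adjoint.
    return [adj + result_adjoint if i == idx else adj
            for i, adj in enumerate(tuple_adjoint)]

def tuple_backprop(field_adjoints, tuple_adjoint):
    # TupleNode rule: every field accumulates its slot of the tuple's adjoint.
    return [fa + ta for fa, ta in zip(field_adjoints, tuple_adjoint)]

# For f(x, y) = (x, y)[0], x receives ones while y stays at zeros:
zeros = [np.zeros((2, 2)), np.zeros((2, 2))]  # MultiZeros-style initialization
grads = tuple_get_item_backprop(zeros, 0, np.ones((2, 2)))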
src/relay/transforms/pattern_util.h: 6 additions & 0 deletions
@@ -524,6 +524,12 @@ inline Expr OnesLike(Expr e) {
  return Call(op, {e});
}

Expr MakeOnes(Expr shape, DataType dtype);

inline Expr Ones(Array<IndexExpr> shape, DataType dtype) {
  return MakeOnes(CheckConstantShape(shape), dtype);
}

inline Expr CollapseSumLike(Expr e) {
  static const Op& op = Op::Get("collapse_sum_like");
  return Call(op, {e});
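The new Ones helper mirrors the existing Zeros helper in pattern_util.h; both construct calls to operators that are exposed in Python as relay.ones and relay.zeros. For illustration (assuming a standard TVM build):

from tvm import relay

# The C++ Ones/Zeros helpers build calls to the same operators as:
ones = relay.ones(shape=(10, 10), dtype="float32")
zeros = relay.zeros(shape=(10, 10), dtype="float32")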
tests/python/relay/test_pass_gradient.py: 21 additions & 7 deletions
@@ -158,20 +158,27 @@ def test_broadcast_subtract():
                                -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))


-def test_tuple():
+def _test_tuple(mode):
    shape = (10, 10)
    dtype = 'float32'
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    y = relay.var("y", t)
    z = relay.var("z", t)
-    tup = relay.Var("tup")
-    func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
-                                               relay.TupleGetItem(tup, 0) +
-                                               relay.TupleGetItem(tup, 1) -
-                                               relay.TupleGetItem(tup, 2)))
+    if mode == "higher_order":
+        tup = relay.Var("tup")
+        func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
+                                                   relay.TupleGetItem(tup, 0) +
+                                                   relay.TupleGetItem(tup, 1) -
+                                                   relay.TupleGetItem(tup, 2)))
+    else:
+        # first order does not do let.
+        tup = relay.Tuple([x, y, z])
+        func = relay.Function([x, y, z], relay.TupleGetItem(tup, 0) +
+                                         relay.TupleGetItem(tup, 1) -
+                                         relay.TupleGetItem(tup, 2))
    func = run_infer_type(func)
-    back_func = run_infer_type(gradient(func))
+    back_func = run_infer_type(gradient(func, mode=mode))
    assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
    x_nd = rand(dtype, *shape)
    y_nd = rand(dtype, *shape)
@@ -188,6 +195,12 @@ def test_tuple():
    tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))


def test_tuple():
    _test_tuple("higher_order")

def test_tuple_first_order():
    _test_tuple("first_order")

def test_pow():
    mod = tvm.IRModule()
    p = Prelude(mod)
@@ -304,6 +317,7 @@ def test_concat():
    test_broadcast_add()
    test_broadcast_subtract()
    test_tuple()
    test_tuple_first_order()
    test_pow()
    test_ref()
    test_square_second_order()
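To run just the tuple tests touched by this commit (assuming a TVM development checkout with pytest installed):

pytest tests/python/relay/test_pass_gradient.py -k "test_tuple"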
