Skip to content

Commit

Permalink
feat(imperative/trace): add scope info to the opr in the backward phase
Browse files Browse the repository at this point in the history
GitOrigin-RevId: c13dcc056e9328a340e24051d2a826ffb552b6dd
  • Loading branch information
megvii-mge committed Dec 18, 2023
1 parent 70127c1 commit 4d6aaa8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
4 changes: 2 additions & 2 deletions imperative/python/megengine/xla/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ def _str_shape(shp):
def _str_eqn(self, eqn):
inps = ", ".join(map(self._str_var, eqn.inputs))
oups = ", ".join(map(self._str_var, eqn.outputs))
-        str_op = str(eqn.op)
+        str_op = str(eqn.type)
if isinstance(eqn.op, mops.Reduce):
assert str(eqn.op.mode).startswith("Reduce.Mode.")
str_op = str_op + str(eqn.op.mode)[len("Reduce.Mode.") :]
-        ret = f"{oups} = {str_op}({inps})"
+        ret = f"{oups} = {str_op}({inps}) scope: {eqn.scope}"
return ret

def __str__(self) -> str:
Expand Down
8 changes: 8 additions & 0 deletions imperative/python/src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,14 @@ void init_tensor(py::module m) {
else
return py::cast(opkind2str.find(self.kind)->second);
})
+            .def_property_readonly(
+                    "scope",
+                    [](SeqItem& self) -> py::object {
+                        if (self.op && !self.op->scope().empty()) {
+                            return py::cast(self.op->scope());
+                        }
+                        return py::none();
+                    })
.def_property_readonly(
"kind",
[opkind2str](SeqItem& self) {
Expand Down
8 changes: 5 additions & 3 deletions imperative/src/impl/transformations/grad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ void GradKey::backward() {
auto& tape = m_frozen_tape;
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto& [grad_fn, op] = tape[k];
+        std::string scope_name = op ? op->make_name() + ".Backward" : "CustomBackward";
+        Transformation::push_scope(scope_name);
auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {
auto& dest = grad_fn->m_dests[i];
if (dest) {
Expand All @@ -233,13 +235,12 @@ void GradKey::backward() {
for (auto&& slot : grad_fn->m_slots) {
*iter++ = slot.m_grad;
}
-                    std::string name = op ? op->name() + "Backward" : "CustomBackward";
                     if (Profiler::is_profiling()) {
-                        imperative::apply(PushScope(name, ScopeType::BACKWARD), Span<ValueRef>(nullptr, nullptr));
+                        imperative::apply(PushScope(scope_name, ScopeType::BACKWARD), Span<ValueRef>(nullptr, nullptr));
                     }
}
backward(grads, grad_receiver);
if (Profiler::is_profiling()) {
-                        imperative::apply(PopScope(name, ScopeType::BACKWARD), Span<ValueRef>(nullptr, nullptr));
+                        imperative::apply(PopScope(scope_name, ScopeType::BACKWARD), Span<ValueRef>(nullptr, nullptr));
}
}
}, grad_fn->m_backward);
Expand All @@ -256,6 +257,7 @@ void GradKey::backward() {
}
}
grad_fn->clear();
+        Transformation::pop_scope(scope_name);
}
tape.clear();
}
Expand Down

0 comments on commit 4d6aaa8

Please sign in to comment.