Merge pull request #9384 from luotao1/removeVar
remove vars when remove ops
luotao1 committed Mar 28, 2018
2 parents 47e4afb + 7f40122 commit 857a899
Showing 3 changed files with 76 additions and 6 deletions.
49 changes: 43 additions & 6 deletions paddle/fluid/framework/block_desc.cc
@@ -147,15 +147,52 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
   if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
     return;
   }
+  auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
+                     std::vector<std::string> &v) {
+    auto in_names = (*op)->InputArgumentNames();
+    v.insert(v.end(), in_names.begin(), in_names.end());
+    auto out_names = (*op)->OutputArgumentNames();
+    v.insert(v.end(), out_names.begin(), out_names.end());
+    std::sort(v.begin(), v.end());
+    auto last = std::unique(v.begin(), v.end());
+    v.erase(last, v.end());
+  };
   need_update_ = true;
-  for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
-    auto names = (*it)->InputArgumentNames();
-    for (auto n : names) {
-      // TODO(typhoonzero): delete vars if no other op use it.
-      VLOG(3) << "deleting var " << n;
+
+  for (size_t i = s; i < e; i++) {
+    // ops are removed one by one, so each iteration erases the first op in [s, e)
+    auto op = ops_.begin() + s;
+
+    // collect the input and output variables of the op being removed
+    std::vector<std::string> cur_vars;
+    get_vars(op, cur_vars);
+
+    // remove the current op
+    ops_.erase(ops_.begin() + s);
+
+    // collect the input and output variables of the remaining ops
+    std::vector<std::string> other_vars;
+    for (auto it = ops_.begin(); it != ops_.end(); it++) {
+      get_vars(it, other_vars);
+    }
+
+    // variables that should be deleted
+    std::vector<std::string> delete_vars;
+    // delete_vars = cur_vars - (cur_vars ∩ other_vars)
+    std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
+                        other_vars.end(),
+                        std::inserter(delete_vars, delete_vars.end()));
+    // remove the variables that no remaining op uses
+    for (size_t i = 0; i < delete_vars.size(); i++) {
+      auto name = delete_vars[i];
+      auto it = vars_.find(name);
+      PADDLE_ENFORCE(it != vars_.end(),
+                     "%s is not in variable list, it should not be deleted",
+                     name);
+      vars_.erase(it);
+      VLOG(3) << "deleting variable " << name;
     }
   }
-  ops_.erase(ops_.begin() + s, ops_.begin() + e);
 }
 
 std::vector<OpDesc *> BlockDesc::AllOps() const {
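For readers skimming the diff, the core of the new RemoveOp is the set-difference step: a variable is deleted only if no remaining op reads or writes it. Below is a small standalone sketch (not part of the commit) of that rule; the function name VarsOnlyUsedByRemovedOp and the sample data are illustrative. Note that std::set_difference requires sorted ranges, which is why the committed get_vars lambda sorts and de-duplicates the collected names.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Names used by the removed op but by none of the remaining ops.
std::vector<std::string> VarsOnlyUsedByRemovedOp(
    std::vector<std::string> cur_vars, std::vector<std::string> other_vars) {
  auto sort_unique = [](std::vector<std::string> &v) {
    std::sort(v.begin(), v.end());
    v.erase(std::unique(v.begin(), v.end()), v.end());
  };
  sort_unique(cur_vars);
  sort_unique(other_vars);
  std::vector<std::string> delete_vars;
  std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
                      other_vars.end(),
                      std::inserter(delete_vars, delete_vars.end()));
  return delete_vars;
}

int main() {
  // Mirrors the unit test below: the removed op touches var1..var4,
  // the remaining op touches var1, var4 and var5.
  std::vector<std::string> cur_vars = {"var1", "var2", "var3", "var4"};
  std::vector<std::string> other_vars = {"var1", "var4", "var5"};
  for (const auto &name : VarsOnlyUsedByRemovedOp(cur_vars, other_vars)) {
    std::cout << "would delete " << name << "\n";  // prints var2 and var3
  }
  return 0;
}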
5 changes: 5 additions & 0 deletions paddle/fluid/framework/block_desc.h
@@ -89,6 +89,11 @@ class BlockDesc {
 
   OpDesc *InsertOp(size_t index);
 
+  /*
+   * Remove Op and its input/output variables.
+   * Note that for either an input or output variable, if it is also an input
+   * or output variable of another op, it should be kept.
+   */
   void RemoveOp(size_t s, size_t e);
 
   std::vector<OpDesc *> AllOps() const;
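The comment above describes the intended semantics; a hedged C++ counterpart of the Python test below might look like the sketch that follows. It is not part of the commit, and the accessor names it assumes (ProgramDesc::MutableBlock, BlockDesc::AppendOp, BlockDesc::Var, BlockDesc::HasVar, OpDesc::SetInput/SetOutput) come from the framework's descriptor API rather than from this diff.

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/program_desc.h"

void RemoveOpExample() {
  using paddle::framework::ProgramDesc;
  ProgramDesc prog;                // a ProgramDesc starts with block 0
  auto *block = prog.MutableBlock(0);

  // op1 reads var1/var2 and writes var3/var4; op2 reads var1 and writes var4/var5.
  auto *op1 = block->AppendOp();
  op1->SetType("dummy");           // type string only; the op is never executed
  op1->SetInput("X", {"var1", "var2"});
  op1->SetOutput("Y", {"var3", "var4"});
  auto *op2 = block->AppendOp();
  op2->SetType("dummy");
  op2->SetInput("X", {"var1"});
  op2->SetOutput("Y", {"var4", "var5"});
  for (auto name : {"var1", "var2", "var3", "var4", "var5"}) {
    block->Var(name);              // register the variables in the block
  }

  // Removing op1 should drop var2 and var3, but keep var1 and var4,
  // which op2 still uses, as well as op2's own output var5.
  block->RemoveOp(0, 1);
  // Expected afterwards: HasVar("var2") == false, HasVar("var1") == true.
}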
28 changes: 28 additions & 0 deletions python/paddle/fluid/tests/unittests/test_protobuf_descs.py
@@ -186,6 +186,34 @@ def test_add_op(self):
             all_ops.append(block.op(idx))
         self.assertEqual(all_ops, [op0, op1, op2])
 
+    def test_remove_op(self):
+        prog = core.ProgramDesc()
+        self.assertIsNotNone(prog)
+        block = prog.block(0)
+        self.assertIsNotNone(block)
+        op1 = block.append_op()
+        op2 = block.append_op()
+        var1 = block.var("var1")
+        var2 = block.var("var2")
+        var3 = block.var("var3")
+        var4 = block.var("var4")
+        var5 = block.var("var5")
+        op1.set_input("X", ["var1", "var2"])
+        op1.set_output("Y", ["var3", "var4"])
+        op2.set_input("X", ["var1"])
+        op2.set_output("Y", ["var4", "var5"])
+
+        # Removing op1 also removes its input var2 and output var3, but keeps
+        # var1 and var4 because they are still used by op2.
+        block.remove_op(0, 1)
+
+        all_ops = []
+        for idx in xrange(0, block.op_size()):
+            all_ops.append(block.op(idx))
+        self.assertEqual(all_ops, [op2])
+        all_vars = block.all_vars()
+        self.assertEqual(set(all_vars), {var1, var4, var5})
+
 
 if __name__ == '__main__':
     unittest.main()
