【pir】modify test_case.py #60976

Merged: merged 3 commits on Jan 22, 2024

Changes from 1 commit
32 changes: 21 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -217,8 +217,14 @@ void IfOp::VerifyRegion() {
1u,
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
(*this)->region(0).size()));
if ((*this)->region(0).front().size() > 0) {
auto &true_last_op = (*this)->region(0).front().back();
if ((*this)->num_results() != 0) {
auto &true_block = (*this)->region(0).front();
PADDLE_ENFORCE_GT(
true_block.size(),
0u,
phi::errors::PreconditionNotMet(
"The true block must have at least one op yield op."));
Contributor suggested:
- "The true block must have at least one op yield op."));
+ "The true block must have at least one op (including yield op)."));

auto &true_last_op = true_block.back();
PADDLE_ENFORCE_EQ(true,
true_last_op.isa<pir::YieldOp>(),
phi::errors::PreconditionNotMet(
@@ -228,15 +234,19 @@
phi::errors::PreconditionNotMet(
"The size of last of true block op's input must be "
"equal to IfOp's outputs num."));
}
VLOG(4) << "Start Verifying false branch.";
PADDLE_ENFORCE_EQ(
(*this)->region(1).size(),
1u,
phi::errors::PreconditionNotMet("The size %d of false_region must be 1.",
(*this)->region(0).size()));
if ((*this)->region(1).front().size() > 0) {
auto &false_last_op = (*this)->region(1).front().back();
VLOG(4) << "Start Verifying false branch.";
Contributor:
The false-branch verification should be placed outside the top-level `if ((*this)->num_results() != 0)` branch. We require that an if op cannot have only one branch, and under the current logic that requirement is never checked when num_results == 0.

Contributor (author):
This is intentional: the check is skipped when num_results == 0, allowing an if op without a yield op.

PADDLE_ENFORCE_EQ((*this)->region(1).size(),
1u,
phi::errors::PreconditionNotMet(
"The size %d of false_region must be 1.",
(*this)->region(0).size()));
auto &false_block = (*this)->region(1).front();
PADDLE_ENFORCE_GT(
false_block.size(),
0u,
phi::errors::PreconditionNotMet(
"The false block must have at least one op yield op."));
Contributor suggested:
- "The false block must have at least one op yield op."));
+ "The false block must have at least one op (including yield op)."));

Contributor (author):
Will optimize this in a follow-up PR.

auto &false_last_op = false_block.back();
PADDLE_ENFORCE_EQ(true,
false_last_op.isa<pir::YieldOp>(),
phi::errors::PreconditionNotMet(
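For context on the rule under discussion, here is a minimal Python sketch (not part of this PR) of the two situations `IfOp::VerifyRegion` now distinguishes. It assumes the public `paddle.static.nn.cond` API; the branch functions are hypothetical illustrations.

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.full(shape=[1], fill_value=1.0, dtype='float32')
    pred = x < 2.0

    # Case 1: the if op has results, so each branch block must be non-empty
    # and end with a yield op whose operand count equals num_results.
    out = paddle.static.nn.cond(pred, lambda: x + 1.0, lambda: x - 1.0)

    # Case 2: the if op has no results (both branches return None). Under
    # the relaxed check in this diff, such branch blocks may omit yield.
    def side_effect_only():
        paddle.assign(x + 1.0, output=x)
        return None

    paddle.static.nn.cond(pred, side_effect_only, side_effect_only)
```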
27 changes: 20 additions & 7 deletions python/paddle/autograd/ir_backward.py
@@ -1058,6 +1058,21 @@ def all_stop_gradient_true(block):
return True


def update_total_ops(block):
'''
When block is a sub_block, the forward ops should include its parent
block's ops (nested sub-blocks are added on demand to avoid block copies).
'''
total_ops = []
if block.parent_block is not None:
Contributor:
Does `total` here mean it includes only the input block's parent and grandparent blocks?

Contributor (author):
In theory it should be all ancestor blocks, but to avoid a while loop, which would cause block copies, it is limited to two levels.

if block.parent_block.parent_block:
total_ops += block.parent_block.parent_block.ops
total_ops += block.parent_block.ops
total_ops += block.ops

return total_ops


def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
block = outputs[0].get_defining_op().get_parent_block()
state = State(block)
@@ -1067,16 +1082,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
)
return state.value_to_valuegrad

total_ops = []
if block.parent_block is not None:
total_ops += block.parent_block.ops
total_ops += block.ops
total_ops = update_total_ops(block)

# update no_grad_set if some value stop_gradient=True
update_no_grad_set_by_stopgradient(block, no_grad_set)
complete_outputs, backward_ops = prepare_grad_outputs(
grad_outputs, outputs, state
)
with block:
complete_outputs, backward_ops = prepare_grad_outputs(
grad_outputs, outputs, state
)

inputs_set = ValueSet(inputs)
stop_gradient_false_outputs = []
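As a side note on the thread above: a fully general version of `update_total_ops` would walk every ancestor block rather than stopping at the grandparent. A hypothetical sketch (not the PR's code, and exactly the loop the author avoided because it caused block copies) might look like:

```python
def update_total_ops_general(block):
    # Hypothetical generalization: gather ops from every ancestor block,
    # outermost first, then the block's own ops. Assumes parent_block is
    # None at the top-level block.
    chain = []
    current = block
    while current is not None:
        chain.append(current)
        current = current.parent_block
    total_ops = []
    for blk in reversed(chain):
        total_ops += blk.ops
    return total_ops
```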
97 changes: 50 additions & 47 deletions test/legacy_test/test_case.py
@@ -415,7 +415,7 @@ def fn_3():
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)

print(main_program)
xiaoguoguo626807 marked this conversation as resolved.
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
@@ -611,58 +611,61 @@ def type_error_default():

# when optimizer in case
class TestMutiTask(unittest.TestCase):
@test_with_pir_api
def test_optimizer_in_case(self):
BATCH_SIZE = 1
INPUT_SIZE = 784
EPOCH_NUM = 2

x = paddle.static.data(
name='x', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
)
y = paddle.static.data(
name='y', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
)

switch_id = paddle.static.data(
name='switch_id', shape=[1], dtype='int32'
)

one = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1)
adam = paddle.optimizer.Adam(learning_rate=0.001)
adagrad = paddle.optimizer.Adagrad(learning_rate=0.001)

def fn_1():
sum = paddle.multiply(x, y)
loss = paddle.mean(sum, name="f_1_loss")
adam.minimize(loss)

def fn_2():
sum = paddle.multiply(x, y)
loss = paddle.mean(sum, name="f_2_loss")
adagrad.minimize(loss)

paddle.static.nn.control_flow.case(
pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2
)

exe = base.Executor(base.CPUPlace())
exe.run(base.default_startup_program())

for epoch in range(EPOCH_NUM):
np.random.seed(epoch)
feed_image = np.random.random(size=[BATCH_SIZE, INPUT_SIZE]).astype(
'float32'
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
name='x', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
)
main_program = base.default_main_program()
out = exe.run(
main_program,
feed={
'x': feed_image,
'y': feed_image,
'switch_id': np.array([epoch]).astype('int32'),
},
fetch_list=[],
y = paddle.static.data(
name='y', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
)
x.stop_gradient = False
y.stop_gradient = False
switch_id = paddle.static.data(
name='switch_id', shape=[1], dtype='int32'
)

one = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1)
adam = paddle.optimizer.Adam(learning_rate=0.001)
adagrad = paddle.optimizer.Adagrad(learning_rate=0.001)

def fn_1():
sum = paddle.multiply(x, y)
loss = paddle.mean(sum, name="f_1_loss")
adam.minimize(loss)

def fn_2():
sum = paddle.multiply(x, y)
loss = paddle.mean(sum, name="f_2_loss")
adagrad.minimize(loss)

paddle.static.nn.control_flow.case(
pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2
)

exe = base.Executor(base.CPUPlace())
exe.run(startup_program)

for epoch in range(EPOCH_NUM):
np.random.seed(epoch)
feed_image = np.random.random(
size=[BATCH_SIZE, INPUT_SIZE]
).astype('float32')
out = exe.run(
main_program,
feed={
'x': feed_image,
'y': feed_image,
'switch_id': np.array([epoch]).astype('int32'),
},
fetch_list=[],
)


if __name__ == '__main__':
3 changes: 2 additions & 1 deletion test/legacy_test/test_dynamic_rnn_stop_gradient.py
@@ -18,6 +18,7 @@

import paddle
from paddle import base
from paddle.pir_utils import test_with_pir_api
xiaoguoguo626807 marked this conversation as resolved.
from paddle.tensor.manipulation import tensor_array_to_tensor

paddle.enable_static()
@@ -77,7 +78,7 @@ def setUp(self):
self.batch_size = 2
self.beam_size = 2

# @test_with_pir_api
@test_with_pir_api
def run_main(self, place):
with paddle.pir_utils.IrGuard():
main_program = paddle.static.Program()
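For readers unfamiliar with the decorator re-enabled above: as used throughout this PR, `test_with_pir_api` runs a static-graph test body under the PIR program as well as the legacy one. A minimal usage sketch, with hypothetical test contents:

```python
import unittest

import paddle
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


class TestExample(unittest.TestCase):
    @test_with_pir_api
    def test_add(self):
        # The body builds and runs a fresh program; the decorator re-runs
        # it with the PIR API enabled (behavior as used in this PR).
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.full(shape=[1], fill_value=1.0, dtype='float32')
            y = x + 1.0
        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup)
        (res,) = exe.run(main, fetch_list=[y])
        self.assertAlmostEqual(float(res[0]), 2.0)


if __name__ == '__main__':
    unittest.main()
```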