diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
index 30d5ce5a1b685..9e3f283c95ec1 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -217,8 +217,14 @@ void IfOp::VerifyRegion() {
       1u,
       phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
                                       (*this)->region(0).size()));
-  if ((*this)->region(0).front().size() > 0) {
-    auto &true_last_op = (*this)->region(0).front().back();
+  if ((*this)->num_results() != 0) {
+    auto &true_block = (*this)->region(0).front();
+    PADDLE_ENFORCE_GT(
+        true_block.size(),
+        0u,
+        phi::errors::PreconditionNotMet(
+            "The true block must have at least one yield op."));
+    auto &true_last_op = true_block.back();
     PADDLE_ENFORCE_EQ(true,
                       true_last_op.isa<pir::YieldOp>(),
                       phi::errors::PreconditionNotMet(
@@ -228,15 +234,19 @@ void IfOp::VerifyRegion() {
                           phi::errors::PreconditionNotMet(
                               "The size of last of true block op's input must be "
                               "equal to IfOp's outputs num."));
-  }
-  VLOG(4) << "Start Verifying false branch.";
-  PADDLE_ENFORCE_EQ(
-      (*this)->region(1).size(),
-      1u,
-      phi::errors::PreconditionNotMet("The size %d of false_region must be 1.",
-                                      (*this)->region(0).size()));
-  if ((*this)->region(1).front().size() > 0) {
-    auto &false_last_op = (*this)->region(1).front().back();
+    VLOG(4) << "Start Verifying false branch.";
+    PADDLE_ENFORCE_EQ((*this)->region(1).size(),
+                      1u,
+                      phi::errors::PreconditionNotMet(
+                          "The size %d of false_region must be 1.",
+                          (*this)->region(0).size()));
+    auto &false_block = (*this)->region(1).front();
+    PADDLE_ENFORCE_GT(
+        false_block.size(),
+        0u,
+        phi::errors::PreconditionNotMet(
+            "The false block must have at least one yield op."));
+    auto &false_last_op = false_block.back();
     PADDLE_ENFORCE_EQ(true,
                       false_last_op.isa<pir::YieldOp>(),
                       phi::errors::PreconditionNotMet(
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 0e0dc739b989e..4779edf0418ac 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -1058,6 +1058,21 @@ def all_stop_gradient_true(block):
     return True
 
 
+def update_total_ops(block):
+    '''
+    When block is a sub_block, the forward ops should include its parent
+    block's ops (nested sub_blocks are added on demand to avoid block copy).
+    '''
+    total_ops = []
+    if block.parent_block is not None:
+        if block.parent_block.parent_block:
+            total_ops += block.parent_block.parent_block.ops
+        total_ops += block.parent_block.ops
+    total_ops += block.ops
+
+    return total_ops
+
+
 def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
     block = outputs[0].get_defining_op().get_parent_block()
     state = State(block)
@@ -1067,16 +1082,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
         )
         return state.value_to_valuegrad
 
-    total_ops = []
-    if block.parent_block is not None:
-        total_ops += block.parent_block.ops
-    total_ops += block.ops
+    total_ops = update_total_ops(block)
 
     # update no_grad_set if some value stop_gradient=True
     update_no_grad_set_by_stopgradient(block, no_grad_set)
-    complete_outputs, backward_ops = prepare_grad_outputs(
-        grad_outputs, outputs, state
-    )
+    with block:
+        complete_outputs, backward_ops = prepare_grad_outputs(
+            grad_outputs, outputs, state
+        )
 
     inputs_set = ValueSet(inputs)
     stop_gradient_false_outputs = []
diff --git a/test/legacy_test/test_case.py b/test/legacy_test/test_case.py
index 1a5cf3e459e6b..93bcff88a1891 100644
--- a/test/legacy_test/test_case.py
+++ b/test/legacy_test/test_case.py
@@ -415,7 +415,6 @@ def fn_3():
             out_1 = paddle.static.nn.control_flow.case(
                 pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
             )
-
             out_2 = paddle.static.nn.control_flow.case(
                 pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
             )
@@ -611,58 +610,61 @@ def type_error_default():
 
 
 # when optimizer in case
 class TestMutiTask(unittest.TestCase):
+    @test_with_pir_api
     def test_optimizer_in_case(self):
         BATCH_SIZE = 1
         INPUT_SIZE = 784
         EPOCH_NUM = 2
-
-        x = paddle.static.data(
-            name='x', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
-        )
-        y = paddle.static.data(
-            name='y', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
-        )
-
-        switch_id = paddle.static.data(
-            name='switch_id', shape=[1], dtype='int32'
-        )
-
-        one = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1)
-        adam = paddle.optimizer.Adam(learning_rate=0.001)
-        adagrad = paddle.optimizer.Adagrad(learning_rate=0.001)
-
-        def fn_1():
-            sum = paddle.multiply(x, y)
-            loss = paddle.mean(sum, name="f_1_loss")
-            adam.minimize(loss)
-
-        def fn_2():
-            sum = paddle.multiply(x, y)
-            loss = paddle.mean(sum, name="f_2_loss")
-            adagrad.minimize(loss)
-
-        paddle.static.nn.control_flow.case(
-            pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2
-        )
-
-        exe = base.Executor(base.CPUPlace())
-        exe.run(base.default_startup_program())
-
-        for epoch in range(EPOCH_NUM):
-            np.random.seed(epoch)
-            feed_image = np.random.random(size=[BATCH_SIZE, INPUT_SIZE]).astype(
-                'float32'
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(
+                name='x', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
             )
-            main_program = base.default_main_program()
-            out = exe.run(
-                main_program,
-                feed={
-                    'x': feed_image,
-                    'y': feed_image,
-                    'switch_id': np.array([epoch]).astype('int32'),
-                },
-                fetch_list=[],
+            y = paddle.static.data(
+                name='y', shape=[BATCH_SIZE, INPUT_SIZE], dtype='float32'
             )
+            x.stop_gradient = False
+            y.stop_gradient = False
+            switch_id = paddle.static.data(
+                name='switch_id', shape=[1], dtype='int32'
+            )
+
+            one = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1)
+            adam = paddle.optimizer.Adam(learning_rate=0.001)
+            adagrad = paddle.optimizer.Adagrad(learning_rate=0.001)
+
+            def fn_1():
+                sum = paddle.multiply(x, y)
+                loss = paddle.mean(sum, name="f_1_loss")
+                adam.minimize(loss)
+
+            def fn_2():
+                sum = paddle.multiply(x, y)
+                loss = paddle.mean(sum, name="f_2_loss")
+                adagrad.minimize(loss)
+
+            paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2
+            )
+
+            exe = base.Executor(base.CPUPlace())
+            exe.run(startup_program)
+
+            for epoch in range(EPOCH_NUM):
+                np.random.seed(epoch)
+                feed_image = np.random.random(
+                    size=[BATCH_SIZE, INPUT_SIZE]
+                ).astype('float32')
+                out = exe.run(
+                    main_program,
+                    feed={
+                        'x': feed_image,
+                        'y': feed_image,
+                        'switch_id': np.array([epoch]).astype('int32'),
+                    },
+                    fetch_list=[],
+                )
 
 
 if __name__ == '__main__':
diff --git a/test/legacy_test/test_dynamic_rnn_stop_gradient.py b/test/legacy_test/test_dynamic_rnn_stop_gradient.py
index 7d28931887d7c..c6d85b864c8d7 100644
--- a/test/legacy_test/test_dynamic_rnn_stop_gradient.py
+++ b/test/legacy_test/test_dynamic_rnn_stop_gradient.py
@@ -18,6 +18,7 @@
 
 import paddle
 from paddle import base
+from paddle.pir_utils import test_with_pir_api
 from paddle.tensor.manipulation import tensor_array_to_tensor
 
 paddle.enable_static()
@@ -77,7 +78,7 @@ def setUp(self):
         self.batch_size = 2
         self.beam_size = 2
 
-    # @test_with_pir_api
+    @test_with_pir_api
     def run_main(self, place):
         with paddle.pir_utils.IrGuard():
             main_program = paddle.static.Program()
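
Note on the update_total_ops helper added to ir_backward.py in this patch: when gradients are computed for values defined inside a sub-block (for example an IfOp branch), the backward pass also needs to see the forward ops of the enclosing block(s), so the helper walks up to two levels of parent blocks and collects ops outermost-first. The following is a minimal, self-contained sketch of that lookup order, not part of the patch; the Block class here is a hypothetical stand-in for Paddle's PIR block type, not the real API.

class Block:
    # Hypothetical stand-in for a PIR block: a list of ops plus an optional parent.
    def __init__(self, ops, parent_block=None):
        self.ops = ops
        self.parent_block = parent_block


def update_total_ops(block):
    # Mirrors the patched helper: include grandparent and parent ops
    # (outermost first) when the block is nested, then the block's own ops.
    total_ops = []
    if block.parent_block is not None:
        if block.parent_block.parent_block:
            total_ops += block.parent_block.parent_block.ops
        total_ops += block.parent_block.ops
    total_ops += block.ops
    return total_ops


main_block = Block(ops=["full", "matmul"])
true_block = Block(ops=["scale"], parent_block=main_block)
inner_block = Block(ops=["add", "yield"], parent_block=true_block)
print(update_total_ops(inner_block))  # ['full', 'matmul', 'scale', 'add', 'yield']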