-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【pir】modify test_case.py #60976
【pir】modify test_case.py #60976
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -217,8 +217,14 @@ void IfOp::VerifyRegion() { | |||||
1u, | ||||||
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", | ||||||
(*this)->region(0).size())); | ||||||
if ((*this)->region(0).front().size() > 0) { | ||||||
auto &true_last_op = (*this)->region(0).front().back(); | ||||||
if ((*this)->num_results() != 0) { | ||||||
auto &true_block = (*this)->region(0).front(); | ||||||
PADDLE_ENFORCE_GT( | ||||||
true_block.size(), | ||||||
0u, | ||||||
phi::errors::PreconditionNotMet( | ||||||
"The true block must have at least one op yield op.")); | ||||||
auto &true_last_op = true_block.back(); | ||||||
PADDLE_ENFORCE_EQ(true, | ||||||
true_last_op.isa<pir::YieldOp>(), | ||||||
phi::errors::PreconditionNotMet( | ||||||
|
@@ -228,15 +234,19 @@ void IfOp::VerifyRegion() { | |||||
phi::errors::PreconditionNotMet( | ||||||
"The size of last of true block op's input must be " | ||||||
"equal to IfOp's outputs num.")); | ||||||
} | ||||||
VLOG(4) << "Start Verifying false branch."; | ||||||
PADDLE_ENFORCE_EQ( | ||||||
(*this)->region(1).size(), | ||||||
1u, | ||||||
phi::errors::PreconditionNotMet("The size %d of false_region must be 1.", | ||||||
(*this)->region(0).size())); | ||||||
if ((*this)->region(1).front().size() > 0) { | ||||||
auto &false_last_op = (*this)->region(1).front().back(); | ||||||
VLOG(4) << "Start Verifying false branch."; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. false branch 的 verify 建议放在最上面的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个功能是为了在num_results= 0 的时候跳过检查,允许if 无yield op |
||||||
PADDLE_ENFORCE_EQ((*this)->region(1).size(), | ||||||
1u, | ||||||
phi::errors::PreconditionNotMet( | ||||||
"The size %d of false_region must be 1.", | ||||||
(*this)->region(0).size())); | ||||||
auto &false_block = (*this)->region(1).front(); | ||||||
PADDLE_ENFORCE_GT( | ||||||
false_block.size(), | ||||||
0u, | ||||||
phi::errors::PreconditionNotMet( | ||||||
"The false block must have at least one op yield op.")); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后续提pr 优化 |
||||||
auto &false_last_op = false_block.back(); | ||||||
PADDLE_ENFORCE_EQ(true, | ||||||
false_last_op.isa<pir::YieldOp>(), | ||||||
phi::errors::PreconditionNotMet( | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1058,6 +1058,21 @@ def all_stop_gradient_true(block): | |
return True | ||
|
||
|
||
def update_total_ops(block): | ||
''' | ||
when block is sub_block, forward op should include its parent block ops | ||
(sub block nest should Add on demand to aviod block copy) | ||
''' | ||
total_ops = [] | ||
if block.parent_block is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里 total 的语义是只包括入参block的父亲和祖父 block? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 理论上是所有的祖block ,但是为了避免使用while,造成block 的拷贝,只限制了两层 |
||
if block.parent_block.parent_block: | ||
total_ops += block.parent_block.parent_block.ops | ||
total_ops += block.parent_block.ops | ||
total_ops += block.ops | ||
|
||
return total_ops | ||
|
||
|
||
def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): | ||
block = outputs[0].get_defining_op().get_parent_block() | ||
state = State(block) | ||
|
@@ -1067,16 +1082,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): | |
) | ||
return state.value_to_valuegrad | ||
|
||
total_ops = [] | ||
if block.parent_block is not None: | ||
total_ops += block.parent_block.ops | ||
total_ops += block.ops | ||
total_ops = update_total_ops(block) | ||
|
||
# update no_grad_set if some value stop_gradient=True | ||
update_no_grad_set_by_stopgradient(block, no_grad_set) | ||
complete_outputs, backward_ops = prepare_grad_outputs( | ||
grad_outputs, outputs, state | ||
) | ||
with block: | ||
complete_outputs, backward_ops = prepare_grad_outputs( | ||
grad_outputs, outputs, state | ||
) | ||
|
||
inputs_set = ValueSet(inputs) | ||
stop_gradient_false_outputs = [] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.