-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR API adaptor No.58, 62, 64, 70】Migrate some ops into pir #59230
Conversation
可以先对应适配 flash_attention,然后在 scaled_dot_product_attention 和 flash_attention 的相关单测中打开 pir 检查 |
@@ -31,6 +31,7 @@ class TestDirichletOp(OpTest): | |||
|
|||
def setUp(self): | |||
self.op_type = "dirichlet" | |||
self.python_api = paddle.distribution.Dirichlet | |||
self.alpha = np.array((1.0, 2.0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python_api 需要是一个 callable 对象,其调用的结果是返回结果。比方说这里的 self.python_api = paddle.distribution.Dirichlet
,如果调用只是实例化了一个 Dirichlet 对象;而 self.python_api = paddle.distribution.Dirichlet(...).sample
,这个 self.python_api 被调用才是返回一个结果。这里单测的结果需要和 self._hypothesis_testing 作比较,所以 self.python_api 需要根据 self._hypothesis_testing 相应的计算类型进行对应的设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个单测的适配先跳过吧,可以在 pr 描述里补充说明一下~ 我暂时还没有想到好的办法去添加这个 self.python_api
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
该文件下还有 TestELU 单测,以及被 create_test_act_bf16_class 和 create_test_act_fp16_class 创建的 bf16 和 fp16 的 TestELU 单测遗漏了~ 麻烦一起适配一下吧
@@ -152,9 +153,20 @@ def test_unpadded(self): | |||
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 | |||
) | |||
|
|||
# test static | |||
@test_with_pir_api | |||
def test_static_unpadded(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
想问下这里为何适配了 flash_attn_unpadded 的单测?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看错了😭
@@ -255,9 +271,30 @@ def test_all(self): | |||
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 | |||
) | |||
|
|||
# test static | |||
def test_static_all(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要使用 @test_with_pir_api
装饰
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python/paddle/distribution/dirichlet.py 里的 _dirichlet 函数需要适配 pir mode
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
@@ -31,6 +31,7 @@ class TestDirichletOp(OpTest): | |||
|
|||
def setUp(self): | |||
self.op_type = "dirichlet" | |||
self.python_api = paddle.distribution.Dirichlet | |||
self.alpha = np.array((1.0, 2.0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个单测的适配先跳过吧,可以在 pr 描述里补充说明一下~ 我暂时还没有想到好的办法去添加这个 self.python_api
@@ -32,15 +33,23 @@ | |||
) | |||
class TestDirichlet(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TestDirichlet 还是建议单独写一个针对 pir 的单测,因为此处旧 ir 单测与 pir 单测兼容起来比较麻烦,可读性较差。比如 TestDirichletPir
其 setUp
阶段:
def setUp(self):
with paddle.pir_utils.IrGuard():
self.program = paddle.static.Program()
with paddle.static.program_guard(self.program):
conc = paddle.static.data(
'conc', self.concentration.shape, self.concentration.dtype
)
self._paddle_diric = paddle.distribution.Dirichlet(conc)
self.feeds = {'conc': self.concentration}
单测写成如下形式:
def test_all(self):
with paddle.pir_utils.IrGuard():
self._test_mean(self.program)
self._test_variance(self.program)
...
辛苦解决一下冲突~ |
done |
@test_with_pir_api | ||
def test_mean(self): | ||
with paddle.static.program_guard(self.program): | ||
[out] = self.executor.run( | ||
self.program, | ||
feed=self.feeds, | ||
fetch_list=[self._paddle_diric.mean], | ||
) | ||
np.testing.assert_allclose( | ||
out, | ||
scipy.stats.dirichlet.mean(self.concentration), | ||
rtol=RTOL.get(str(self.concentration.dtype)), | ||
atol=ATOL.get(str(self.concentration.dtype)), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@test_with_pir_api | |
def test_mean(self): | |
with paddle.static.program_guard(self.program): | |
[out] = self.executor.run( | |
self.program, | |
feed=self.feeds, | |
fetch_list=[self._paddle_diric.mean], | |
) | |
np.testing.assert_allclose( | |
out, | |
scipy.stats.dirichlet.mean(self.concentration), | |
rtol=RTOL.get(str(self.concentration.dtype)), | |
atol=ATOL.get(str(self.concentration.dtype)), | |
) | |
def test_mean(self): | |
with paddle.pir_utils.IrGuard(): | |
with paddle.static.program_guard(self.program): | |
[out] = self.executor.run( | |
self.program, | |
feed=self.feeds, | |
fetch_list=[self._paddle_diric.mean], | |
) | |
np.testing.assert_allclose( | |
out, | |
scipy.stats.dirichlet.mean(self.concentration), | |
rtol=RTOL.get(str(self.concentration.dtype)), | |
atol=ATOL.get(str(self.concentration.dtype)), | |
) |
此处不宜使用 @test_with_pir_api
,被 @test_with_pir_api
修饰的单测会在旧 ir 模式和 pir 模式下分别运行一次。因为该单测为 pir 单测,所以我们只需要在 Pir 模式下运行就好了。使用 paddle.pir_utils.IrGuard()
作为上下文管理器可以切换到 pir 模式进行组网
atol=ATOL.get(str(self.concentration.dtype)), | ||
) | ||
|
||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
atol=ATOL.get(str(self.concentration.dtype)), | ||
) | ||
|
||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
atol=ATOL.get(str(self.concentration.dtype)), | ||
) | ||
|
||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
atol=ATOL.get(str(self.concentration.dtype)), | ||
) | ||
|
||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
def test_all(self): | ||
self._test_mean(self.program) | ||
self._test_variance(self.program) | ||
self._test_prob(self.program) | ||
self._test_log_prob(self.program) | ||
self._test_entropy(self.program) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_all(self): | |
self._test_mean(self.program) | |
self._test_variance(self.program) | |
self._test_prob(self.program) | |
self._test_log_prob(self.program) | |
self._test_entropy(self.program) |
该单测可以删除了,因为当前不存在 _test_mean
等函数
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | ||
self.sample_shape | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | |
self.sample_shape | |
) |
先复原对 test/distribution/test_dirichlet_op.py
文件的改动,并在 pr 描述里说明一下未适配该单测~
Sorry to inform you that 8fca8f9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
Sorry to inform you that 33253f2's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
@longranger2 辛苦重新提交代码解决一下代码冲突 |
相关API我们内部来进行适配吧~ |
Sorry to inform you that 9711cc9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR types
Others
PR changes
APIs
Description
PIR API 推全升级
将如下算子迁移升级至 pir,并更新单测
Dirichlet(0/3):test_dirichlet_op.py单测暂不打开,因为暂时还没有想到好的办法去添加这个 self.python_api
test/distribution/test_dirichlet_op.py未适配该单测
eigvals(19/19)
ELU(1/2):其中一个是用来检测error的
scaled_dot_product_attention
新IR Python API适配升级 #58067