Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.58, 62, 64, 70】Migrate some ops into pir #59230

Closed
wants to merge 13 commits into from

Conversation

longranger2
Copy link
Contributor

@longranger2 longranger2 commented Nov 21, 2023

PR types

Others

PR changes

APIs

Description

PIR API 推全升级
将如下算子迁移升级至 pir,并更新单测

  • Dirichlet(0/3):test_dirichlet_op.py单测暂不打开,因为暂时还没有想到好的办法去添加这个 self.python_api
    test/distribution/test_dirichlet_op.py未适配该单测

  • eigvals(19/19)

  • ELU(1/2):其中一个是用来检测error的

  • scaled_dot_product_attention

  • 新IR Python API适配升级 #58067

@longranger2
Copy link
Contributor Author

第70个这个不清楚要怎么进行修改呢?paddle.nn.functional.scaled_dot_product_attention

image

@luotao1 luotao1 added contributor External developers HappyOpenSource 快乐开源活动issue与PR labels Nov 22, 2023
@MarioLulab
Copy link
Contributor

scaled_dot_product_attention

可以先对应适配 flash_attention,然后在 scaled_dot_product_attention 和 flash_attention 的相关单测中打开 pir 检查

@@ -31,6 +31,7 @@ class TestDirichletOp(OpTest):

def setUp(self):
self.op_type = "dirichlet"
self.python_api = paddle.distribution.Dirichlet
self.alpha = np.array((1.0, 2.0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python_api 需要是一个 callable 对象,其调用的结果是返回结果。比方说这里的 self.python_api = paddle.distribution.Dirichlet,如果调用只是实例化了一个 Dirichlet 对象;而 self.python_api = paddle.distribution.Dirichlet(...).sample,这个 self.python_api 被调用才是返回一个结果。这里单测的结果需要和 self._hypothesis_testing 作比较,所以 self.python_api 需要根据 self._hypothesis_testing 相应的计算类型进行对应的设置

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个单测的适配先跳过吧,可以在 pr 描述里补充说明一下~ 我暂时还没有想到好的办法去添加这个 self.python_api

This comment was marked as off-topic.

This comment was marked as resolved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该文件下还有 TestELU 单测,以及被 create_test_act_bf16_class 和 create_test_act_fp16_class 创建的 bf16 和 fp16 的 TestELU 单测遗漏了~ 麻烦一起适配一下吧

@@ -152,9 +153,20 @@ def test_unpadded(self):
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)

# test static
@test_with_pir_api
def test_static_unpadded(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

想问下这里为何适配了 flash_attn_unpadded 的单测?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看错了😭

@@ -255,9 +271,30 @@ def test_all(self):
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)

# test static
def test_static_all(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要使用 @test_with_pir_api 装饰

Copy link
Contributor

@MarioLulab MarioLulab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python/paddle/distribution/dirichlet.py 里的 _dirichlet 函数需要适配 pir mode

This comment was marked as resolved.

@@ -31,6 +31,7 @@ class TestDirichletOp(OpTest):

def setUp(self):
self.op_type = "dirichlet"
self.python_api = paddle.distribution.Dirichlet
self.alpha = np.array((1.0, 2.0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个单测的适配先跳过吧,可以在 pr 描述里补充说明一下~ 我暂时还没有想到好的办法去添加这个 self.python_api

@@ -32,15 +33,23 @@
)
class TestDirichlet(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestDirichlet 还是建议单独写一个针对 pir 的单测,因为此处旧 ir 单测与 pir 单测兼容起来比较麻烦,可读性较差。比如 TestDirichletPir
setUp 阶段:

def setUp(self):
    with paddle.pir_utils.IrGuard():
        self.program = paddle.static.Program()
        with paddle.static.program_guard(self.program):
            conc = paddle.static.data(
                'conc', self.concentration.shape, self.concentration.dtype
            )
            self._paddle_diric = paddle.distribution.Dirichlet(conc)
            self.feeds = {'conc': self.concentration}

单测写成如下形式:

def test_all(self): 
    with paddle.pir_utils.IrGuard():
         self._test_mean(self.program)
         self._test_variance(self.program)
         ...

@MarioLulab
Copy link
Contributor

辛苦解决一下冲突~

@longranger2
Copy link
Contributor Author

done

Comment on lines 143 to 156
@test_with_pir_api
def test_mean(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test_with_pir_api
def test_mean(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)
def test_mean(self):
with paddle.pir_utils.IrGuard():
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)

此处不宜使用 @test_with_pir_api,被 @test_with_pir_api 修饰的单测会在旧 ir 模式和 pir 模式下分别运行一次。因为该单测为 pir 单测,所以我们只需要在 Pir 模式下运行就好了。使用 paddle.pir_utils.IrGuard() 作为上下文管理器可以切换到 pir 模式进行组网

atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

Comment on lines 228 to 233
def test_all(self):
self._test_mean(self.program)
self._test_variance(self.program)
self._test_prob(self.program)
self._test_log_prob(self.program)
self._test_entropy(self.program)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_all(self):
self._test_mean(self.program)
self._test_variance(self.program)
self._test_prob(self.program)
self._test_log_prob(self.program)
self._test_entropy(self.program)

该单测可以删除了,因为当前不存在 _test_mean 等函数

Comment on lines 106 to 108
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
self.sample_shape
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
self.sample_shape
)

先复原对 test/distribution/test_dirichlet_op.py 文件的改动,并在 pr 描述里说明一下未适配该单测~

Copy link

paddle-ci-bot bot commented Jan 5, 2024

Sorry to inform you that 8fca8f9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

Copy link

paddle-ci-bot bot commented Jan 14, 2024

Sorry to inform you that 33253f2's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@YuanRisheng
Copy link
Contributor

@longranger2 辛苦重新提交代码解决一下代码冲突

@0x45f
Copy link
Contributor

0x45f commented Jan 24, 2024

相关API我们内部来进行适配吧~

Copy link

paddle-ci-bot bot commented Jan 25, 2024

Sorry to inform you that 9711cc9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@luotao1 luotao1 closed this Mar 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants