Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Perf]Call cublas.Gemm instead of cublasGemmStridedBatched #1407

Merged
merged 2 commits into from
May 12, 2023

Conversation

Aurelius84
Copy link
Collaborator

@Aurelius84 Aurelius84 commented May 9, 2023

what's New?

复现代码:

#!/usr/bin/env python3
# Please set "export PYTHONPATH=${CINN_ROOT}/build/python:${PYTHONPATH}" first
import paddle
import unittest
import numpy as np
import cinn
from cinn.frontend import *
from cinn.common import *
from op_test import OpTest

class TestGroup(unittest.TestCase):
  def test_group(self):
    builder = NetBuilder("matmul")
    x_shape = [128, 128, 768]
    y_shape = [768, 768]

    x = builder.create_input(Float16(),x_shape, "x")
    y = builder.create_input(Float16(), y_shape, "y")
    out = builder.matmul(
            x, y, transpose_x=False, transpose_y=False)

    feed_list = [x, y]
    fetch_list = [out]

    prog = builder.build()

    feed_data = [OpTest.random(shape=var.shape(), dtype=var.type()) for var in feed_list]
    result = prog.build_and_get_output(DefaultNVGPUTarget(), feed_list, feed_data, fetch_list)

    result = [res.numpy(DefaultNVGPUTarget()) for res in result]
    for i in range(len(result)):
      info_str = fetch_list[i].name()
      info_str += ", shape=" + str(result[i].shape)
      info_str += ", dtype=" + str(result[i].dtype) + ":\n"
      print(info_str)

if __name__ == "__main__":
  unittest.main()

生效日志:
image

@paddle-bot
Copy link

paddle-bot bot commented May 9, 2023

Thanks for your contribution!

Copy link
Collaborator

@thisjiang thisjiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

麻烦顺便在python/tests/ops/test_matmul_op.py 下加上对应的测试用例吧

@Aurelius84
Copy link
Collaborator Author

麻烦顺便在python/tests/ops/test_matmul_op.py 下加上对应的测试用例吧

好的

@Aurelius84 Aurelius84 requested a review from thisjiang May 10, 2023 08:39
Copy link
Collaborator

@thisjiang thisjiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Aurelius84 Aurelius84 merged commit 06424c2 into PaddlePaddle:develop May 12, 2023
jiahy0825 pushed a commit to jiahy0825/CINN that referenced this pull request May 25, 2023
…dle#1407)

* [Perf]Call cublas.Gemm instead of cublasGemmStridedBatched

* fix unittest
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants