Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

make slice operator 20x faster on GPU #11124

Merged
merged 2 commits into from
Jun 2, 2018

Conversation

eric-haibin-lin
Copy link
Member

Description

The current slice kernel is slow on GPU as each thread in the warp accesses different memory address in the output.

@reminisce @haojin2 @safrooze

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here
import mxnet as mx
import time

a = mx.nd.ones((512, 8192), ctx=mx.gpu())
for i in range(10):
    b = a.slice(begin=(None, 4), end=(None, None))

mx.nd.waitall()
t0 = time.time()
for i in range(1000):
    b = a.slice(begin=(None, 4), end=(None, None))

mx.nd.waitall()
t1 = time.time()
print(t1 - t0)

Time on K80

  • before: 11.9892659187 (s)
  • after: 0.497007131577 (s)

@eric-haibin-lin eric-haibin-lin changed the title [WIP] make slice operator 20x faster on GPU make slice operator 20x faster on GPU Jun 2, 2018
@piiswrong piiswrong merged commit a068fae into apache:master Jun 2, 2018
eric-haibin-lin added a commit to eric-haibin-lin/mxnet that referenced this pull request Jun 11, 2018
* gpu slice kernel

* remove unused line
anirudh2290 pushed a commit that referenced this pull request Jun 11, 2018
* gpu slice kernel

* remove unused line
@ThomasDelteil ThomasDelteil mentioned this pull request Jun 17, 2018
7 tasks
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* gpu slice kernel

* remove unused line
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* gpu slice kernel

* remove unused line
@BiranLi
Copy link

BiranLi commented Sep 7, 2018

No acceleration on the V100.

@safrooze
Copy link
Contributor

safrooze commented Sep 7, 2018

@BiranLi How did you test this? I've seen 20x improvement on V100 using p3.2x instances.

@BiranLi
Copy link

BiranLi commented Sep 7, 2018

@safrooze Hi. What is the test environment on your side? What is the version of the mxnet used for comparison? What's your test code ?

@BiranLi
Copy link

BiranLi commented Sep 7, 2018

@safrooze I have tested in V100 between mxnet_0.12 and mxnet_1.3.
before: 0.09s
after:0.08s

@eric-haibin-lin
Copy link
Member Author

@BiranLi i'm using this test code:

import mxnet as mx
import time

a = mx.nd.ones((512, 8192), ctx=mx.gpu())
for i in range(10):
    b = a.slice(begin=(None, 4), end=(None, None))

mx.nd.waitall()
t0 = time.time()
for i in range(1000):
    b = a.slice(begin=(None, 4), end=(None, None))

mx.nd.waitall()
t1 = time.time()
print(t1 - t0)

What did you use? Are you on slack?

@eric-haibin-lin eric-haibin-lin deleted the slice-gpu branch September 7, 2018 21:46
@BiranLi
Copy link

BiranLi commented Sep 10, 2018

Yes, I was using the same test code. I get a similar result on both version mxnet(0.12 vs 1.3).

@eric-haibin-lin
Copy link
Member Author

Oh I didn't compare with mxnet 0.12. Maybe it was a regression between 0.12 and 1.2

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants