
Improve cast operator performance #3783

Merged (6 commits, May 4, 2022)

Conversation

@MirazSpecial (Contributor) commented Apr 2, 2022

Signed-off-by: Konrad Litwiński klitwinski41418@gmail.com

Category:

Refactoring (Redesign of existing code that doesn't affect functionality)

Description:

The main motivation of this work was to improve Cast throughput for small batches of data.

Originally, running the Cast kernel (BatchedCastKernel) required copying two arrays to the GPU:

  • a samples array (descriptors of where each sample is placed in memory), using 8 * number_of_samples bytes,
  • a blocks array (descriptors of what each thread block should do), using 20 * number_of_blocks bytes.

Since the number of blocks grows linearly with the data size (the number of blocks is only about 1024 times smaller than the data size), copying the second array was a large part of the cost of running the Cast kernel.

The idea of this optimization is that, instead of copying the blocks array, we build an array describing how big each sample is and which block is the first one to process it, copy that array to the GPU, and then, in the kernel, compute which sample each block should work on. To compute that efficiently, we use a binary search over the sample descriptors.
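The block-to-sample lookup described above can be sketched on the host as follows. This is a simplified illustration, not the actual kernel code: the `SampleDesc` struct and `FindSample` name are hypothetical stand-ins for the real `CastSampleBlockDesc` and the device-side search in `BinSearchCastKernel`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified sample descriptor; the real
// CastSampleBlockDesc in cast.cuh carries more fields.
struct SampleDesc {
  int first_block;  // index of the first thread block assigned to this sample
  int size;         // number of elements in the sample
};

// Given a block index, find the sample whose block range contains it.
// This mirrors the descending-jump binary search in the kernel: start
// with the largest power of two <= nsamples and halve the jump each step.
int FindSample(const std::vector<SampleDesc> &samples, int block_idx) {
  int nsamples = static_cast<int>(samples.size());
  int i = 0;
  int jump = 1;
  while (jump * 2 <= nsamples)  // largest power of two not exceeding nsamples
    jump *= 2;
  for (; jump; jump >>= 1) {
    if (i + jump < nsamples && samples[i + jump].first_block <= block_idx)
      i += jump;
  }
  // i is now the largest index whose first_block does not exceed block_idx
  return i;
}
```

In the kernel, each thread block runs this search on `blockIdx.x`, so only the small per-sample descriptor array has to be copied to the GPU.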

Additional information:

For image size 1000x1000 we achieved the following improvement:
[benchmark results image attached in the PR]

Affected modules and functionalities:

cast.cuh
cast.cu

Key points relevant for the review:

The key changes are in the newly added BinSearchCastKernel kernel.

Checklist

Tests

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

@MirazSpecial changed the title from "Cast operator optimization using binary searc" to "Cast operator optimization using binary search" on Apr 2, 2022
Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
@MirazSpecial changed the title from "Cast operator optimization using binary search" to "Improve cast operator performance" on Apr 2, 2022
@mzient mzient self-assigned this Apr 4, 2022
dali/kernels/common/cast.cuh (outdated review thread, resolved)
@szalpal (Member) commented Apr 4, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4381653]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4381653]: BUILD FAILED

Co-authored-by: Michał Zientkiewicz <mzient@gmail.com>
        const CastSampleBlockDesc *params,
        int nsamples, int block_volume_scale) {
      int i = 0;
      for (int jump = (1 << (32 - __clz(nsamples) - 1)); jump; jump >>= 1) {
@mzient (Contributor) commented Apr 4, 2022

How about calculating (1 << (32 - __clz(nsamples) - 1)) outside and passing it as a kernel parameter? You can use ilog2 function. I'm not saying this is mandatory, but I'm curious if that would yield a measurable change in performance (one way or the other).

@MirazSpecial (Contributor, Author) replied:

Passing (1 << (32 - __clz(nsamples) - 1)) as a parameter would mean adding a fifth kernel parameter (nsamples still needs to be passed, as it is used elsewhere in the kernel).

As for performance, moving this calculation outside the kernel doesn't change performance in any significant way. AFAIK the whole binary search has almost no impact on performance (I tried removing it and choosing a random block to parse, and it didn't improve throughput).
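For reference, the expression under discussion, `1 << (32 - __clz(nsamples) - 1)`, is simply the largest power of two not exceeding `nsamples`, i.e. `1 << ilog2(nsamples)`. A portable host-side sketch of that equivalence, where `ilog2` is written out explicitly and may differ from DALI's actual helper:

```cpp
#include <cassert>
#include <cstdint>

// Integer base-2 logarithm for n > 0: position of the highest set bit.
// Stands in for the ilog2 helper mentioned in the review; the exact
// DALI signature may differ.
int ilog2(uint32_t n) {
  int log = 0;
  while (n >>= 1)
    ++log;
  return log;
}

// Host-side equivalent of the device expression
//   1 << (32 - __clz(nsamples) - 1)
// i.e. the initial jump of the descending binary search.
int InitialJump(uint32_t nsamples) {
  return 1 << ilog2(nsamples);
}
```

Precomputing this on the host and passing it in would trade one extra kernel parameter for a few instructions per block, which, per the discussion above, makes no measurable difference here.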

@szalpal (Member) commented Apr 13, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4500452]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4500452]: BUILD FAILED

Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
@szalpal (Member) commented Apr 26, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4680741]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4680741]: BUILD FAILED

@klecki (Contributor) commented Apr 26, 2022

Lint is complaining:

#14 58.59 /opt/dali/dali/kernels/common/cast.cuh:30:  At least two spaces is best between code and comments  [whitespace/comments] [2]
#14 58.59 /opt/dali/dali/kernels/common/cast.cuh:58:  At least two spaces is best between code and comments  [whitespace/comments] [2]

@szalpal (Member) commented Apr 26, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4683075]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4683075]: BUILD FAILED

Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
@szalpal (Member) commented Apr 27, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4687549]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4687549]: BUILD FAILED

@klecki (Contributor) commented Apr 27, 2022

#14 251.6 /opt/dali/dali/kernels/common/cast.cuh:39:45: error: comparison of integers of different signs: 'int' and 'unsigned int' [-Werror,-Wsign-compare]
#14 251.6   for (int x = threadIdx.x + block_start; x < block_end; x += blockDim.x) {
#14 251.6                                           ~ ^ ~~~~~~~~~

(This is detected by the clang-only build, which has more thorough error checking for CUDA code.)
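The error above comes from comparing a signed loop index (`int x`) with an unsigned bound (`block_end`). A minimal host-side reproduction, with one common fix (making both sides of the comparison signed); this illustrates the class of fix, not necessarily the exact change made in the PR:

```cpp
#include <cassert>

// Minimal reproduction of the -Wsign-compare issue from the CI log.
// The stride and descriptor details of the real kernel loop are omitted.
int CountElements(int start, unsigned int block_end) {
  int count = 0;
  // Before: for (int x = start; x < block_end; ++x)   // int vs unsigned
  // After:  cast the bound so both operands are signed
  for (int x = start; x < static_cast<int>(block_end); ++x)
    ++count;
  return count;
}
```

With `-Wsign-compare` promoted to an error (`-Werror`), the signed operand would otherwise be implicitly converted to unsigned, which silently misbehaves for negative values.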

Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
@klecki (Contributor) commented May 2, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4731433]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4731433]: BUILD FAILED

@szalpal (Member) commented May 4, 2022

!build

@dali-automaton (Collaborator)

CI MESSAGE: [4747662]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [4747662]: BUILD PASSED

@szalpal szalpal merged commit f492132 into NVIDIA:main May 4, 2022
cyyever pushed a commit to cyyever/DALI that referenced this pull request May 13, 2022
* Use binary search to find the sample to process
* Extracting params to CastSampleBlockDesc

Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
cyyever pushed a commit to cyyever/DALI that referenced this pull request Jun 7, 2022
* Use binary search to find the sample to process
* Extracting params to CastSampleBlockDesc

Signed-off-by: Konrad Litwiński <klitwinski41418@gmail.com>
7 participants