
Optimize the DCT GPU kernel. #2471

Merged (7 commits, Nov 20, 2020)
Conversation

banasraf
Collaborator

Why we need this PR?

  • Refactoring to improve the performance of the GPU DCT kernel. The current implementation is extremely slow when the transform runs over the inner axis, so this case is now handled separately. The existing CUDA kernel was also slightly optimized.

What happened in this PR?

  • What solution was applied:
    The case with the transform done over the inner axis is handled with a separate CUDA kernel. The existing kernel was optimized by employing shared memory.
  • Affected modules and functionalities:
    GPU DCT kernel.
  • Key points relevant for the review:
    A new CUDA kernel and the changes to the existing one.
  • Validation and testing:
    Existing tests still apply. I added a performance test.
  • Documentation (including examples):
    N/A
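As a rough illustration of the shared-memory approach described above (a hypothetical sketch, not DALI's actual `ApplyDctInner` kernel; all names and the block/row mapping are assumptions): one block processes one row, the row is first loaded cooperatively into shared memory with coalesced reads, and each thread then computes DCT-II coefficients from the fast on-chip copy instead of re-reading global memory.

```cuda
// Hypothetical sketch of an inner-axis DCT-II kernel; not the PR's code.
// One block per row; dynamic shared memory holds the whole input row.
extern __shared__ float row[];

__global__ void DctInnerSketch(const float *in, float *out, int n) {
  const float kPi = 3.14159265358979f;
  const float *src = in + static_cast<int64_t>(blockIdx.x) * n;
  float *dst = out + static_cast<int64_t>(blockIdx.x) * n;
  // Cooperative, coalesced load of the row into shared memory.
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    row[i] = src[i];
  __syncthreads();
  // Each thread computes output coefficients k, k + blockDim.x, ...
  for (int k = threadIdx.x; k < n; k += blockDim.x) {
    float acc = 0.f;
    for (int i = 0; i < n; ++i)  // DCT-II: sum x_i * cos(pi/n * (i + 0.5) * k)
      acc += row[i] * cosf(kPi / n * (i + 0.5f) * k);
    dst[k] = acc;
  }
}
// Launch with the dynamic shared-memory size as the third launch parameter:
// DctInnerSketch<<<num_rows, 256, n * sizeof(float)>>>(in, out, n);
```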

JIRA TASK: DALI-1690

Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
@banasraf
Collaborator Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [1805832]: BUILD STARTED

@JanuszL
Contributor

JanuszL commented Nov 17, 2020

What is the speed now, and what used to be before that optimization?

@dali-automaton
Collaborator

CI MESSAGE: [1805832]: BUILD FAILED

__global__ void ApplyDctInner(const typename Dct1DGpu<OutputType, InputType>::SampleDesc *samples,
                              const BlockSetupInner::BlockDesc *blocks,
                              const float *lifter_coeffs) {
  extern __shared__ char *shm[];
Contributor

Suggested change
extern __shared__ char *shm[];
extern __shared__ char shm[];
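The suggestion matters beyond style: `extern __shared__ char *shm[];` declares the dynamic shared buffer as an array of `char *` pointers, so indexing it walks in pointer-sized strides over the wrong type. The conventional pattern is an untyped `char` array cast to the element type the kernel needs. A minimal sketch of that pattern (hypothetical kernel and names, not the PR's code):

```cuda
// Conventional dynamic shared memory: declare raw bytes, cast to the
// element type at the point of use.
extern __shared__ char shm[];

__global__ void ScaleRows(const float *in, float *out, int row_len) {
  // Hypothetical kernel: stage one row per block in shared memory.
  float *row = reinterpret_cast<float *>(shm);  // view raw bytes as floats
  int64_t base = static_cast<int64_t>(blockIdx.x) * row_len;
  for (int i = threadIdx.x; i < row_len; i += blockDim.x)
    row[i] = in[base + i];
  __syncthreads();
  for (int i = threadIdx.x; i < row_len; i += blockDim.x)
    out[base + i] = 2.0f * row[i];
}
// The dynamic size is the third launch parameter:
// ScaleRows<<<rows, 256, row_len * sizeof(float)>>>(in, out, row_len);
```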

Collaborator Author

done

Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
struct BlockDesc {
  int64_t sample_idx;
  int64_t frame_start;
  int64_t frames_num;
Contributor

Suggested change
int64_t frames_num;
int64_t num_frames;

or

Suggested change
int64_t frames_num;
int64_t frame_count;

Collaborator Author

done

Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
@banasraf
Collaborator Author

!build

@banasraf
Collaborator Author

banasraf commented Nov 17, 2020

@JanuszL

What is the speed now, and what used to be before that optimization?

For the planar layout it's ~550 GFLOPS -> ~630 GFLOPS and for interleaved it's ~30 GFLOPS -> ~315 GFLOPS

@dali-automaton
Collaborator

CI MESSAGE: [1806677]: BUILD STARTED

Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
@@ -40,7 +40,7 @@ def define_graph(self):
test_data_root = get_dali_extra_path()
good_path = 'db/single'
missnamed_path = 'db/single/missnamed'
test_good_path = {'jpeg', 'mixed', 'png', 'tiff', 'pnm', 'bmp', 'jpeg2k'}
test_good_path = {'jpeg2k'}
Contributor

???

Collaborator Author

eeeh, good catch. I've committed too much

Collaborator Author

done

@dali-automaton
Collaborator

CI MESSAGE: [1806677]: BUILD FAILED

Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
Signed-off-by: Rafal <Banas.Rafal97@gmail.com>
@banasraf
Collaborator Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [1809572]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [1809572]: BUILD PASSED

@banasraf banasraf merged commit d2f08b3 into NVIDIA:master Nov 20, 2020