
Tune Arithmetic Op launch specification #2137

Merged: 5 commits merged into NVIDIA:master on Jul 22, 2020

Conversation

@klecki (Contributor) commented Jul 21, 2020

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>

Why we need this PR?

Arithmetic Ops were not fast enough: the GPU was underutilized because the tiles were not big enough.

What happened in this PR?

  • What solution was applied:
    The tile size was increased and the grid size was reduced, so each thread performs a few more iterations (see the sketch after this list).
    Pointer usage was adjusted slightly, which helps a bit with small inputs.
    A benchmark was added.

  • Affected modules and functionalities:
    Arithmetic Ops

  • Key points relevant for the review:

  • Validation and testing:
    Benchmark

  • Documentation (including examples):
    NA
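
For context, the following is a minimal, self-contained sketch of the grid-stride pattern this PR tunes: the grid is capped so that each thread iterates over several elements, and the input/output pointers are offset once instead of being re-derived every iteration. All identifiers (demo_mul_kernel, kBlockSize, kMaxGrid) and the constants are illustrative and are not the actual DALI launch specification.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBlockSize = 256;   // illustrative block size
constexpr int kMaxGrid = 1024;    // capping the grid forces several iterations per thread

// Multiply a tensor by a constant (tensor-op-constant case) with a grid-stride loop.
__global__ void demo_mul_kernel(float *result, const float *in, float c, int64_t extent) {
  int64_t start = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  // Offset the pointers once, then advance them together with the index.
  result += start;
  in += start;
  for (int64_t i = start; i < extent; i += stride, result += stride, in += stride) {
    *result = *in * c;
  }
}

int main() {
  const int64_t extent = int64_t{1} << 22;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, extent * sizeof(float));
  cudaMallocManaged(&out, extent * sizeof(float));
  for (int64_t i = 0; i < extent; i++) in[i] = static_cast<float>(i);
  // A smaller grid than "one thread per element" makes every thread loop.
  int64_t blocks_needed = (extent + kBlockSize - 1) / kBlockSize;
  int grid = static_cast<int>(std::min<int64_t>(blocks_needed, kMaxGrid));
  demo_mul_kernel<<<grid, kBlockSize>>>(out, in, 2.0f, extent);
  cudaDeviceSynchronize();
  printf("out[123] = %f\n", out[123]);  // expected 246.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}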

OLD:

test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'const'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f5fd44f6ae8>, [(1024, 1024)] * 256, '*') ... Throughput: 377.034 GB/s
Throughput: 375.006 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'const'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f5fd44f6ae8>, [(16384, 1024)] * 64, '*') ... Throughput: 391.476 GB/s
Throughput: 405.331 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'const'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f5fd44f6ae8>, [(400, 400)] * 64, '*') ... Throughput: 334.448 GB/s
Throughput: 353.607 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'gpu'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f5fd3e962f0>, [(1024, 1024)] * 256, '*') ... Throughput: 490.778 GB/s
Throughput: 490.198 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'gpu'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f5fd3e962f0>, [(16384, 1024)] * 64, '*') ... Throughput: 499.623 GB/s
Throughput: 507.385 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'gpu'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f5fd3e962f0>, [(400, 400)] * 64, '*') ... Throughput: 358.166 GB/s
Throughput: 402.577 GB/s
ok

----------------------------------------------------------------------
Ran 6 tests in 10.587s

OK

NEW:

test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'const'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f96c80bbae8>, [(1024, 1024)] * 256, '*') ... Throughput: 533.855 GB/s
Throughput: 531.948 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'const'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f96c80bbae8>, [(16384, 1024)] * 64, '*') ... Throughput: 576.293 GB/s
Throughput: 576.419 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'const'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f96c80bbae8>, [(400, 400)] * 64, '*') ... Throughput: 462.535 GB/s
Throughput: 498.008 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'gpu'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f96c7a5c2f0>, [(1024, 1024)] * 256, '*') ... Throughput: 571.668 GB/s
Throughput: 572.734 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'gpu'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f96c7a5c2f0>, [(16384, 1024)] * 64, '*') ... Throughput: 604.232 GB/s
Throughput: 603.718 GB/s
ok
test_operator_arithmetic_ops.test_arithmetic_ops_perf(('gpu', 'gpu'), (<class 'numpy.float32'>, <class 'numpy.float32'>), <function test_arithmetic_ops_perf.<locals>.<lambda> at 0x7f96c7a5c2f0>, [(400, 400)] * 64, '*') ... Throughput: 447.778 GB/s
Throughput: 459.841 GB/s
ok

----------------------------------------------------------------------
Ran 6 tests in 10.426s

OK

JIRA TASK: [Use DALI-1514 or NA]

Add benchmark, adjust test to new tile size

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki (Contributor Author) commented Jul 21, 2020

!build

@dali-automaton (Collaborator)

CI MESSAGE: [1484299]: BUILD STARTED

for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
result[i] = meta::impl(l[i], r);
*result = meta::impl(*l, r);
@mzient (Contributor) Jul 21, 2020:

Why? Does using pointer arithmetic instead of indexing help in any way?

Contributor Author (klecki):

It seems so with smaller inputs, with bigger ones it doesn't make much difference.

Contributor:

In that case, can we store the offset and step in variables to avoid repeating?

Contributor Author (klecki):

Done I guess.
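
A hedged sketch of what the requested refactor could look like for the tensor-op-constant variant quoted above; arithm_meta, GPUBackend, and meta::impl come from the surrounding diff, while the exact body below is illustrative rather than the merged DALI code:

template <ArithmeticOp op, typename Result, typename Left, typename Right>
__device__ void ExecuteBinOp(Result *result, const Left *l, Right r, int64_t extent) {
  using meta = arithm_meta<op, GPUBackend>;
  // Compute the thread's starting offset and the grid-wide step once...
  int64_t start = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  // ...offset the pointers once...
  result += start;
  l += start;
  // ...and advance the pointers and the index together instead of re-deriving them.
  for (int64_t i = start; i < extent; i += stride, result += stride, l += stride) {
    *result = meta::impl(*l, r);
  }
}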

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
for (int sample_id = 0; sample_id < TestConfig::batch_size; sample_id++) {
for (int extent_id = 0; extent_id < TestConfig::tiles_per_sample; extent_id++) {
int tile_id = sample_id * TestConfig::tiles_per_sample + extent_id;
tiles_cpu(tile_id)->desc = tile_descs[tile_id];
Contributor:

Can't you fill tile_descs here as well?

Contributor Author (klecki):

I could, but implementing that lambda would probably be a bit more of a hassle than a direct loop.

@@ -33,8 +33,12 @@ namespace dali {
template <ArithmeticOp op, typename Result, typename Input>
__device__ void ExecuteUnOp(Result *result, const Input *in, int64_t extent) {
using meta = arithm_meta<op, GPUBackend>;
result += blockIdx.x * blockDim.x + threadIdx.x;
in += blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
Contributor:

extract blockDim.x * gridDim.x to a variable?

for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
result[i] = meta::impl(in[i]);
Contributor:

actually it seemed more readable before. Why the change?

Contributor Author (klecki):

For a small perf gain as well as correctness with int64 offset.
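
To illustrate the correctness point (with hypothetical sizes, not values taken from this PR): if the offset or loop index is kept in a 32-bit type, an extent larger than INT_MAX elements overflows, so the loop can terminate early or index out of bounds, whereas int64_t keeps the arithmetic exact.

#include <cstdint>
#include <cstdio>

int main() {
  int64_t extent = int64_t{3} << 30;                 // 3 Gi elements -- a legal 64-bit extent
  int32_t truncated = static_cast<int32_t>(extent);  // wraps around: no 32-bit index reaches it
  printf("extent as int64_t: %lld, truncated to int32_t: %d\n",
         static_cast<long long>(extent), truncated);
  return 0;
}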

for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
result[i] = meta::impl(l[i], r[i]);
*result = meta::impl(*l, *r);
Contributor:

same applies here. Why the change?

auto left = static_cast<const Left *>(tile.args[0]);
auto right = static_cast<const Right *>(tile.args[1]);
auto *output = static_cast<Result *>(tile.output);
const auto *left = static_cast<const Left *>(tile.args[0]);
Contributor:

I'd go with either

Suggested change
const auto *left = static_cast<const Left *>(tile.args[0]);
const Left *left = static_cast<const Left *>(tile.args[0]);

or

Suggested change
const auto *left = static_cast<const Left *>(tile.args[0]);
auto left = static_cast<const Left *>(tile.args[0]);

Contributor Author (klecki):

I can go back to auto left

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@@ -44,8 +51,17 @@ __device__ void ExecuteUnOp(Result *result, const Input *in, int64_t extent) {
template <ArithmeticOp op, typename Result, typename Left, typename Right>
__device__ void ExecuteBinOp(Result *result, const Left *l, const Right *r, int64_t extent) {
using meta = arithm_meta<op, GPUBackend>;
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
result[i] = meta::impl(l[i], r[i]);
uint32_t start_ofs = (blockDim.x) * blockIdx.x + threadIdx.x;
Contributor:

in the previous implementation, you used int64_t. Can we use int64_t everywhere?

Contributor Author (klecki):

Added by mistake, will revert to previous int64_t change.

@@ -55,8 +71,15 @@ __device__ void ExecuteBinOp(Result *result, const Left *l, const Right *r, int6
template <ArithmeticOp op, typename Result, typename Left, typename Right>
__device__ void ExecuteBinOp(Result *result, Left l, const Right *r, int64_t extent) {
using meta = arithm_meta<op, GPUBackend>;
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
result[i] = meta::impl(l, r[i]);
uint32_t start_ofs = (blockDim.x) * blockIdx.x + threadIdx.x;
Contributor:

same here

@@ -66,16 +89,23 @@ __device__ void ExecuteBinOp(Result *result, Left l, const Right *r, int64_t ext
template <ArithmeticOp op, typename Result, typename Left, typename Right>
__device__ void ExecuteBinOp(Result *result, const Left *l, Right r, int64_t extent) {
using meta = arithm_meta<op, GPUBackend>;
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < extent; i += blockDim.x * gridDim.x) {
result[i] = meta::impl(l[i], r);
uint32_t start_ofs = (blockDim.x) * blockIdx.x + threadIdx.x;
Contributor:

and here

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki (Contributor Author) commented Jul 21, 2020

!build

@dali-automaton (Collaborator)

CI MESSAGE: [1484540]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [1484540]: BUILD FAILED

@klecki (Contributor Author) commented Jul 21, 2020

!build

@dali-automaton (Collaborator)

CI MESSAGE: [1484876]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [1484876]: BUILD FAILED

@dali-automaton (Collaborator)

CI MESSAGE: [1484876]: BUILD PASSED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki (Contributor Author) commented Jul 22, 2020

!build

@dali-automaton (Collaborator)

CI MESSAGE: [1486961]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [1486961]: BUILD FAILED

@klecki (Contributor Author) commented Jul 22, 2020

!build

@dali-automaton (Collaborator)

CI MESSAGE: [1487392]: BUILD STARTED

@klecki (Contributor Author) commented Jul 22, 2020

!build

@dali-automaton (Collaborator)

CI MESSAGE: [1487604]: BUILD STARTED

@dali-automaton (Collaborator)

CI MESSAGE: [1487392]: BUILD FAILED

@dali-automaton (Collaborator)

CI MESSAGE: [1487604]: BUILD PASSED

@klecki merged commit 498a22e into NVIDIA:master on Jul 22, 2020
@klecki deleted the arithm-op-perf branch on July 22, 2020, 17:09