Add optimized variant of CMN for HWC to CHW case #4972

Merged: 10 commits into NVIDIA:main on Aug 7, 2023

klecki
Contributor

@klecki klecki commented Aug 2, 2023

Category: New feature, Refactoring

Description:

Add an optimized version of SliceFlipNormalize kernel for the HWC to CHW layout switch.
The kernel has several variants that allow for:

  • cropping (only the X dimension is relevant, as the Y dimension is handled via tiling)
  • mirroring the X coordinate (as required by the CropMirrorNormalize operator)
  • padding the channel dimension

It assumes uint8_t inputs and allows for float16 and float32 outputs.
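For readers unfamiliar with the operation, the following is a minimal CPU reference sketch of what the variants above compute; it is not the GPU kernel from this PR. The function name, its parameters and the `fill` value for padded channels are assumptions made for illustration, and the sketch processes all rows since Y cropping is left to the tiling.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference, not the DALI kernel: crops columns
// [x0, x0 + out_w) of a uint8 HWC image, optionally mirrors X, normalizes
// each channel and writes float CHW output padded to out_c channel planes.
std::vector<float> CropMirrorNormalizeHwcToChw(
    const std::vector<uint8_t> &in, int h, int w, int c,
    int x0, int out_w, bool mirror,
    const std::vector<float> &mean, const std::vector<float> &inv_std,
    int out_c, float fill = 0.0f) {
  assert(static_cast<int>(in.size()) == h * w * c);
  assert(x0 >= 0 && x0 + out_w <= w && out_c >= c);
  assert(static_cast<int>(mean.size()) == c);
  assert(static_cast<int>(inv_std.size()) == c);
  // Padded channel planes keep the fill value.
  std::vector<float> out(static_cast<size_t>(out_c) * h * out_w, fill);
  for (int y = 0; y < h; y++) {
    for (int x = 0; x < out_w; x++) {
      // With mirroring, output column x reads the crop window right-to-left.
      int src_x = mirror ? x0 + out_w - 1 - x : x0 + x;
      for (int ch = 0; ch < c; ch++) {
        float v = in[(static_cast<size_t>(y) * w + src_x) * c + ch];
        out[(static_cast<size_t>(ch) * h + y) * out_w + x] =
            (v - mean[ch]) * inv_std[ch];
      }
    }
  }
  return out;
}
```

A reference like this is mainly useful as a baseline to compare the optimized output against in tests.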

The algorithm is described in the docstring.
Additionally, due to the linear tiling, the bin-search for tile index is ported over from the Cast kernel.
Due to time constraints, I duplicated some code; it may be worth generalizing this approach
to more kernels in a follow-up PR.
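As a rough illustration of the tile lookup mentioned above, here is a host-side sketch of finding the owning sample for a linear tile index with a binary search. It is not the code ported from the Cast kernel; `FindSampleForTile` and the `first_tile` prefix-sum array are hypothetical names, and the sketch assumes tiles are numbered consecutively across samples.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: first_tile[i] is the index of the first tile of
// sample i (an exclusive prefix sum of per-sample tile counts), so the
// array is non-decreasing and starts at 0.
int FindSampleForTile(const std::vector<int64_t> &first_tile, int64_t tile_idx) {
  assert(!first_tile.empty() && first_tile[0] == 0 && tile_idx >= 0);
  // upper_bound returns the first entry greater than tile_idx; the sample
  // that owns the tile is the entry just before it.
  auto it = std::upper_bound(first_tile.begin(), first_tile.end(), tile_idx);
  return static_cast<int>(it - first_tile.begin()) - 1;
}
```

With `first_tile = {0, 3, 3, 7}`, tiles 0 to 2 map to sample 0, sample 1 is empty, and tiles 3 to 6 map to sample 2. On the GPU, each block would run an equivalent search on its `blockIdx.x`.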

The CropMirrorNormalize operator setup is generalized to support the new and old versions
of the Slice kernel (most notably their setup).
The appropriate implementation is now selected based on the inputs and parameters.
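A sketch of what such a selection predicate could look like, using only the constraints stated in this description (uint8_t input, float16 or float32 output, HWC to CHW). All names here (`DType`, `Layout`, `CmnConfig`, `CanUseHwcToChwVariant`) are hypothetical, and the actual operator may check further conditions.

```cpp
enum class DType { UInt8, Int16, Float16, Float32 };
enum class Layout { HWC, CHW };

// Hypothetical bundle of the properties the selection looks at.
struct CmnConfig {
  DType input_type;
  DType output_type;
  Layout input_layout;
  Layout output_layout;
};

// True when the configuration matches the constraints stated in the PR
// description; anything else would fall back to the generic Slice-based path.
constexpr bool CanUseHwcToChwVariant(const CmnConfig &cfg) {
  return cfg.input_type == DType::UInt8 &&
         (cfg.output_type == DType::Float16 ||
          cfg.output_type == DType::Float32) &&
         cfg.input_layout == Layout::HWC &&
         cfg.output_layout == Layout::CHW;
}
```

Keeping the check in one predicate means the operator setup can branch once and then run the setup that matches the chosen kernel.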

Testing is done via the Python layer for simplicity.

The pure kernel (disregarding the setups) achieves 2.6 TB/s vs the 1.6 TB/s of the previous variant.

A simple benchmark using a DALI pipeline gives 2 TB/s for the new kernel vs 1.5 TB/s for the old one. Note that here we are self-restricted by the previous iteration; overlapping the compute with training may help further.

Additional information:

Affected modules and functionalities:

New slice kernel, CropMirrorNormalize op.

Key points relevant for the review:

Kernel impl

Tests:

Existing operator tests + new Python tests focusing on the parameters used for this kernel variant.

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

@klecki
Contributor Author

klecki commented Aug 2, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9209468]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [9209468]: BUILD FAILED

@klecki
Contributor Author

klecki commented Aug 2, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9209664]: BUILD STARTED

Integrate it into current CropMirrorNormalize,
generalize the setup parts of the operator to allow for selection of the
optimized implementations.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@dali-automaton
Collaborator

CI MESSAGE: [9209664]: BUILD FAILED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki
Contributor Author

klecki commented Aug 2, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9210806]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [9210806]: BUILD FAILED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki
Contributor Author

klecki commented Aug 3, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9224809]: BUILD STARTED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki
Contributor Author

klecki commented Aug 3, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9224890]: BUILD STARTED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki klecki marked this pull request as ready for review August 3, 2023 11:51
@dali-automaton
Collaborator

CI MESSAGE: [9224890]: BUILD PASSED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki
Contributor Author

klecki commented Aug 3, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9228248]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [9228248]: BUILD FAILED

@dali-automaton
Collaborator

CI MESSAGE: [9228248]: BUILD PASSED

aligned_tile[idx * 4 + 2] = in.z;
aligned_tile[idx * 4 + 3] = in.w;
}
int64_t processed_in_main = (left_after_prologue / 4) * 4;
Contributor

@mzient mzient Aug 4, 2023

One instruction instead of 4 (signed division by a power of 2 requires three instructions).

Suggested change
- int64_t processed_in_main = (left_after_prologue / 4) * 4;
+ int64_t processed_in_main = left_after_prologue & -4;

https://godbolt.org/z/GcWvxTPoW
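For reference, a small sketch comparing the two forms; the function names are made up for illustration. The two agree for non-negative values, which is the assumed case for a remaining-element count. For negative values they differ, since division truncates toward zero while the mask rounds toward negative infinity.

```cpp
#include <cassert>
#include <cstdint>

// Division form: rounds toward zero to a multiple of 4.
int64_t RoundDownTo4Div(int64_t n) {
  return (n / 4) * 4;
}

// Mask form: -4 is ...11111100 in two's complement, so the AND clears the
// two lowest bits. Assumes n is non-negative, as an element count would be.
int64_t RoundDownTo4Mask(int64_t n) {
  assert(n >= 0);
  return n & -4;
}
```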

Contributor Author

Done

idx < end_x / kStaticChannels; idx += blockDim.x, base_x += blockDim.x) {
// TODO(klecki): forceinline device function
int64_t out_offset;
if constexpr (enable_mirror) {
Contributor

Does it help?

Contributor Author

Yes, compiling out those ifs when they are not used actually helped.
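To illustrate the point, a minimal host-side sketch of a compile-time branch (C++17). `OutputX` is a hypothetical helper rather than the code from this PR; the idea is that each instantiation contains only one of the two paths, so the non-mirrored variant carries no mirror check at run time.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: output X position for input column x in a row of the
// given width. The flag is a template parameter, so the discarded branch is
// not compiled into the instantiation at all.
template <bool enable_mirror>
int64_t OutputX(int64_t x, int64_t width) {
  assert(x >= 0 && x < width);
  if constexpr (enable_mirror) {
    return width - 1 - x;  // present only in OutputX<true>
  } else {
    return x;              // present only in OutputX<false>
  }
}
```

The cost of this approach is one kernel instantiation per flag combination, which has to be dispatched from the host based on the run-time arguments.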

for (int64_t idx = threadIdx.x + start_x / kStaticChannels, base_x = threadIdx.x;
idx < end_x / kStaticChannels; idx += blockDim.x, base_x += blockDim.x) {
int64_t out_offset;
if constexpr (enable_mirror) {
Contributor

Does it help?


float *tile_row = tile;

for (int y = y_start; y < y_end; y++) {
Contributor

This looks suspicious with the inner block-strided loop. If the slice is narrow (< 4*blockDim.x), then many threads will do nothing.

Contributor Author

We include the channel dimension here, so the narrowest slice that utilizes the whole block is about 44 pixels. It's not that narrow.

Comment on lines 157 to 160
), DALI_FAIL(make_string("Not supported channel dimension:", channel_dim_idx_));); // NOLINT
), DALI_FAIL(make_string("Not supported number of spatial dimensions:", spatial_ndim_));); // NOLINT
), DALI_FAIL(make_string("Not supported output type:", output_type_));); // NOLINT
), DALI_FAIL(make_string("Not supported input type:", input_type_));); // NOLINT
Contributor

Nitpick:

Suggested change
- ), DALI_FAIL(make_string("Not supported channel dimension:", channel_dim_idx_));); // NOLINT
- ), DALI_FAIL(make_string("Not supported number of spatial dimensions:", spatial_ndim_));); // NOLINT
- ), DALI_FAIL(make_string("Not supported output type:", output_type_));); // NOLINT
- ), DALI_FAIL(make_string("Not supported input type:", input_type_));); // NOLINT
+ ), DALI_FAIL(make_string("Unsupported channel dimension:", channel_dim_idx_));); // NOLINT
+ ), DALI_FAIL(make_string("Unsupported number of spatial dimensions:", spatial_ndim_));); // NOLINT
+ ), DALI_FAIL(make_string("Unsupported output type:", output_type_));); // NOLINT
+ ), DALI_FAIL(make_string("Unsupported input type:", input_type_));); // NOLINT

Contributor Author

Done

// const auto &req = k.Setup(ctx, sh, cargs);
// // k.test();
auto cargs = make_cspan(args);
auto &req = kmgr_.Setup<Kernel>(0, ctx, sh, make_cspan(args));
Contributor

Suggested change
- auto &req = kmgr_.Setup<Kernel>(0, ctx, sh, make_cspan(args));
+ auto &req = kmgr_.Setup<Kernel>(0, ctx, sh, cargs);

? Otherwise cargs is unused.

Contributor Author

Good catch

Contributor

@mzient mzient left a comment

We need it, so ✔️; however, some parts need a follow-up.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
@klecki
Contributor Author

klecki commented Aug 7, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [9266334]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [9266334]: BUILD PASSED

@klecki klecki merged commit dbb79d4 into NVIDIA:main Aug 7, 2023
5 checks passed
stiepan pushed a commit that referenced this pull request Aug 8, 2023
JanuszL pushed a commit to JanuszL/DALI that referenced this pull request Oct 13, 2023

4 participants