Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resampling ND - ground work #1366

Merged
merged 5 commits into from Nov 8, 2019
Merged

Resampling ND - ground work #1366

merged 5 commits into from Nov 8, 2019

Conversation

mzient
Copy link
Contributor

@mzient mzient commented Oct 11, 2019

Signed-off-by: Michal Zientkiewicz michalz@nvidia.com

Why we need this PR?

Pick one

  • It is intermediate step for Volume Resampling

What happened in this PR?

  • Number of dimensions in resampling kernels is now a template argument
  • Pass number moved from template argument to function argument (no need to specialize over that one)
  • Important: switched coordinate ordering; when unclear, the general scheme is: dim is a tensor-like dimensions (0 is outermost), axis is vector-like (0 is X).

JIRA TASK: [DALI-1075]

@mzient
Copy link
Contributor Author

mzient commented Oct 11, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [940918]: BUILD STARTED

dali/kernels/imgproc/resample/resampling_batch.cu Outdated Show resolved Hide resolved
* @remarks The function clamps input coordinates to fit in range defined by `in` dimensions.
* Scales can be negative to achieve flipping.
*/
template <typename Out, typename In, int n>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: rename n to spatial_ndim for consistency

const OutTensorCPU<SampleBlockInfo, 1> &sample_lookup) {
assert(sample_lookup.shape[0] >= total_blocks.pass[0] + total_blocks.pass[1]);
int blocks_in_all_passes = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idea: surround this by #if DALI_DEBUG

dali/kernels/imgproc/resample/resampling_setup.cc Outdated Show resolved Hide resolved
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [940918]: BUILD PASSED

@@ -92,7 +87,7 @@ __global__ void BatchedSeparableResampleKernel(
}
break;
case ResamplingFilterType::Linear:
if (axis == 1) {
if (axis == spatial_ndim - 1) {
LinearHorz(x0, x1, y0, y1, origin, scale, sample_out, out_stride, sample_in, in_stride,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this Horz/Vert naming is a bit misleading.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idea: you could have some constants to make this more readable:
(e.g kHorizontalAxis = 0 and kHorizontalDim = spatial_ndim - 1)

@mzient
Copy link
Contributor Author

mzient commented Oct 14, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [943957]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [943957]: BUILD PASSED

@@ -81,6 +81,8 @@ struct ResamplingParams {
};

using ResamplingParams2D = std::array<ResamplingParams, 2>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using ResamplingParams2D = std::array<ResamplingParams, 2>;
using ResamplingParams2D = ResamplingParamsND<2>;

?

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Switch coordinate order in resampling from DHW (matrix rank) to XYZ (geometric).

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Change some separate scalar x,y to vectors in resampling kernel.

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@mzient mzient changed the title Rework number of dimensions to be a template argument. Resampling ND - ground work Nov 8, 2019
@mzient
Copy link
Contributor Author

mzient commented Nov 8, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [982102]: BUILD STARTED

Unlock int16, uint16 and int32 resize.

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@mzient
Copy link
Contributor Author

mzient commented Nov 8, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [982113]: BUILD STARTED

DeviceArray<uintptr_t, num_buffers> pointers;
DeviceArray<ptrdiff_t, num_buffers> offsets;
DeviceArray<Strides, num_buffers> strides;
DeviceArray<Shape, num_buffers> shapes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can align the the shape with the rest of variables.

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [982113]: BUILD PASSED

@mzient mzient merged commit 4d31efd into NVIDIA:master Nov 8, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants