Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract windows GPU #1538

Merged
merged 13 commits into from
Dec 19, 2019
Merged

Extract windows GPU #1538

merged 13 commits into from
Dec 19, 2019

Conversation

mzient
Copy link
Contributor

@mzient mzient commented Dec 3, 2019

Why we need this PR?

Pick one

  • It adds a feature required to implement STFT for GPU

What happened in this PR?

  • added ExtractWindows kernel for GPU with horizontal and vertical layout

How is this tested?

  • C++ unit tests against naive implementation (some small batches with various parameters + size sweep to check for edge cases)

JIRA TASK: [DALI-1168]

@mzient mzient requested review from jantonguirao, JanuszL and a team December 3, 2019 18:02
@mzient
Copy link
Contributor Author

mzient commented Dec 3, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1015952]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1015952]: BUILD FAILED

@mzient
Copy link
Contributor Author

mzient commented Dec 4, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1017208]: BUILD STARTED

struct ExtractWindowsGPU<Dst, Src>::Impl : public ExtractWindowsGPUImpl<Dst, Src> {
};

template <typename Dst, typename Src>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned earlier, the CPU variant of extract windows can work on 2D inputs (e.g. stereo audio signal) as long as the temporal dimension is the inner-most. Would it be possible to do something similar here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the innermost dimension? Rather hard to do for the general case. For C=2 it should be doable. I'm not really convinced that this kernel is going to stay at all, so I wouldn't put too much effort into making it very flexible.

/// @remarks This function must be executed by all (or no) threads in a block!
template <int num_pages = 1, typename Dst, typename Src>
__device__ void ExtractWindowsBlock(
int first_window_idx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: from the first read perspective at least, the order of those arguments seems a bit random.

break;
}
}
float v = idx >= 0 && idx < length ? ConvertNorm<float>(src[idx]) * w : Src();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it a bit non-intuitive that we are doing normalization here on the fly when doing type conversion. Even if it makes sense for most of the usages it is not really documented and it is well hidden in the CUDA kernel implementation.

My opinion is that we should probably leave type conversion to the decoder and let this kernel to work on Dst = Src

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it applies a windowing function during the extraction, it is not a normalization. But it deserves a similar documentation as ExtractHorizontalWindows.

blockIdx.x * kBlock, // first window index
dst, num_windows, stride, // output
src, length, // input
window, win_len, win_center, step, reflect); // windowing options
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
window, win_len, win_center, step, reflect); // windowing options
window, win_len, win_center, win_step, reflect); // windowing options

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I'd rename window_step to step in ExtractWindowsArgs - window_step might mean other things, e.g. step between samples when extracting dilated windows (e.g. for multi-channel data or just downsampling).

namespace kernels {
namespace signal {

struct ExtractWindowsGPUArgs : ExtractWindowsArgs {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct ExtractWindowsGPUArgs : ExtractWindowsArgs {
struct ExtractWindowsGPUArgs : public ExtractWindowsArgs {

static_assert(std::is_same<Dst, float>::value, "Output type must be float");
static_assert(
std::is_same<Src, float>::value ||
std::is_same<Src, int8_t>::value ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to keep the type conversion, we should document that when converting from int types to float, the range is normalized to [-1.0, 1.0]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

ExtractWindowsArgs args;
args.window_length = window.empty() ? 55 : window.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Those tests could be parametrized. At the very least, those magic numbers could be declared constants at the top of the test

out_shape.set_tensor_shape(0, { out_win_length, total_windows });
}

while (xgrid > 0x10000 && xgrid > N) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please make 0x10000 a named constant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can, I just don't know if it really helps to have kMaxBlocks = 0x10000 defined in previous line and used just once...

const InListGPU<Src, 1> &input,
const InTensorGPU<float, 1> &window,
const ExtractWindowsBatchedArgs &args) {
(void)args;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it here only to pass the "unused" warning?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is.

@mzient mzient changed the title Extract windows gpu Extract windows GPU Dec 4, 2019
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1017208]: BUILD FAILED

@mzient
Copy link
Contributor Author

mzient commented Dec 9, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1023982]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1023982]: BUILD PASSED

};

struct BlockDesc {
int sample_idx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add some @brief to sample_idx as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

};

struct HorizontalBlockDesc {
int sample_idx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add some @brief to sample_idx as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

namespace signal {

struct ExtractWindowsBatchedArgs : ExtractWindowsArgs {
bool vertical = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some @brief for vertical ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

/**
* @brief If true, all outputs are concatenated.
*
* In case of vertical windows, tThe concatenated output will contain all first samples from
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* In case of vertical windows, tThe concatenated output will contain all first samples from
* In case of vertical windows, the concatenated output will contain all first samples from

/**
* @brief Indicates that the output should be overallocated (or windows truncated) to this size.
*/
int padded_output_window = -1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it map to out_win_length?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Neither name really reflects what it does... it's hard make a descriptive name for it, really.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can't figure the best name, at least use one everywhere.

@mzient mzient force-pushed the ExtractWindowsGPU branch 3 times, most recently from 46a5651 to 187e0a6 Compare December 16, 2019 12:03
@mzient
Copy link
Contributor Author

mzient commented Dec 16, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1034835]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1034835]: BUILD FAILED

@mzient
Copy link
Contributor Author

mzient commented Dec 16, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1035008]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1035008]: BUILD PASSED

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@mzient
Copy link
Contributor Author

mzient commented Dec 19, 2019

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1040589]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [1040589]: BUILD PASSED

@mzient mzient merged commit 2e368b1 into NVIDIA:master Dec 19, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants