-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add GPU variant of Spectrogram operator #1786
Conversation
Check out this pull request on You'll be able to see Jupyter notebook diff and discuss changes. Powered by ReviewNB. |
Add tests for GPU spectrogram op. Extend tests for custom window functions. Switch to GPU spectrogram in Jupyter example. Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
da79b99
to
0c729ed
Compare
!build |
CI MESSAGE: [1165994]: BUILD STARTED |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't GPU STFT kernel provides a build-in toDecibel conversion as well? How it is exposed here?
CI MESSAGE: [1165994]: BUILD PASSED |
The conversion to decibels requires a reference which, by default, is the maximum of the signal - we don't have min/max for GPU yet. When we do (which we'll have to to implement ToDecibels anyway), we can insert additional call to the max kernel in Spectrogram and enable this fusion at the operator level - it should be quite easy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, except some minor comments
make_string("Invalid window length: ", args.window_length)); | ||
DALI_ENFORCE(args.window_step > 0, make_string("Invalid window step: ", args.window_step)); | ||
DALI_ENFORCE(args.window_length <= args.nfft, | ||
make_string("`window_length` must not exceed ransform size (`nfft`). Got nfft = ", args.nfft, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make_string("`window_length` must not exceed ransform size (`nfft`). Got nfft = ", args.nfft, | |
make_string("`window_length` must not exceed transform size (`nfft`). Got nfft = ", args.nfft, |
auto &in = ws.InputRef<GPUBackend>(0); | ||
KernelContext ctx; | ||
ctx.gpu.stream = ws.stream(); | ||
auto in_shape = in.shape().to_static<1>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should add some validation of the input (e.g. number of dimensions) and give a meaningful error msg to the user if it is not as expected
ctx.gpu.stream = ws.stream(); | ||
auto in_shape = in.shape().to_static<1>(); | ||
auto req = kmgr.Setup<SpectrogramGPU>(0, ctx, in_shape, args); | ||
CopyWindowToDevice(ctx.gpu.stream); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not just copying the window to the GPU during construction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if the device is properly set there - and I don't have a stream, so I'd have to resort to cudaDeviceSynchronize
!build |
CI MESSAGE: [1168412]: BUILD STARTED |
Added reshape function for TensorListView. Fixed const-correctness in TensorListView construction with double pointer. Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
f1fe27e
to
79221d6
Compare
!build |
CI MESSAGE: [1168436]: BUILD STARTED |
CI MESSAGE: [1168436]: BUILD PASSED |
Add tests for GPU spectrogram op.
Extend tests for custom window functions.
Switch to GPU spectrogram in Jupyter example.
Signed-off-by: Michal Zientkiewicz michalz@nvidia.com
Why we need this PR?
Pick one, remove the rest
What happened in this PR?
Fill relevant points, put NA otherwise. Replace anything inside []
JIRA TASK: DALI-1168 (followup)