MFCC CPU operator #1577

Merged
17 commits merged on Jan 7, 2020

Contributor

@jantonguirao commented Dec 13, 2019

Why we need this PR?

  • It adds a new feature: MFCC (Mel Frequency Cepstrum Coefficients) operator

What happened in this PR?

  • Explain solution of the problem, new feature added.
    Adds a new CPU operator MFCC that calculates the MFCCs from a mel spectrogram
  • What was changed, added, removed?
  • What is most important part that reviewers should focus on?
    Operator implementation
  • Was this PR tested? How?
    Python operator tests
  • Were docs and examples updated, if necessary?
    Doxygen, schema docstring, jupyter notebook
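
As a rough illustration of what "calculating MFCCs from a mel spectrogram" involves: the coefficients come from a discrete cosine transform applied along the mel-band axis, keeping the first `n_mfcc` outputs. The sketch below is plain Python, not the operator's C++ code. The helper names `dct2` and `mfcc_from_mel`, the frame-major layout, and the unnormalized factor-of-2 DCT-II scaling are assumptions made for illustration and are not taken from this PR.

```python
import math


def dct2(x, n_out=None):
    """Unnormalized DCT-II of a 1-D sequence (scaling by 2 is an assumption)."""
    n = len(x)
    n_out = n if n_out is None else n_out
    return [
        2.0 * sum(x[i] * math.cos(math.pi * k * (2 * i + 1) / (2 * n))
                  for i in range(n))
        for k in range(n_out)
    ]


def mfcc_from_mel(mel_frames, n_mfcc):
    """Apply the DCT to every frame of an (already log-scaled) mel spectrogram.

    mel_frames: list of frames, each a list of mel-band values (assumed layout).
    """
    return [dct2(frame, n_mfcc) for frame in mel_frames]


# Two toy frames with four mel bands each; keep the first two coefficients.
frames = [[1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 2.0, 3.0]]
coeffs = mfcc_from_mel(frames, n_mfcc=2)
```

A constant frame puts all of its energy into coefficient 0, which gives a quick sanity check for any DCT implementation.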

JIRA TASK: [DALI-1186]

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [1040371]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [1040371]: BUILD FAILED

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao changed the title from "[WIP] MFCC CPU operator" to "MFCC CPU operator" on Dec 19, 2019
@jantonguirao requested a review from a team on December 19, 2019 16:14
@dali-automaton
Collaborator

CI MESSAGE: [1040371]: BUILD PASSED


DALI_SCHEMA(MFCC)
.DocStr(R"code(Mel Frequency Cepstral Coefficients (MFCC).
Computes MFCCs from a mel spectrogram)code")
Contributor

Suggested change
- Computes MFCCs from a mel spectrogram)code")
+ Computes MFCCs from a mel spectrogram.)code")

R"code(Cepstral filtering (also known as `liftering`) coefficient.
If `lifter>0`, the MFCCs will be scaled according to the following formula::

MFFC[i] = MFCC[i] * (1 + sin(pi * (i + 1) / lifter)) * (lifter / 2)
Contributor

The arguments documented after this one render with broken formatting. Can you check what is wrong here?

Contributor

see above.

0)
.AddOptionalArg("lifter",
R"code(Cepstral filtering (also known as `liftering`) coefficient.
If `lifter>0`, the MFCCs will be scaled according to the following formula::
Contributor

Trailing double colon starts a pre-formatted block.

Suggested change
- If `lifter>0`, the MFCCs will be scaled according to the following formula::
+ If `lifter>0`, the MFCCs will be scaled according to the following formula:

Contributor Author

I know, that's what @klecki suggested I can use to display equations and such

Contributor

Yes, I would stick with it.

auto &req = kmgr_.Setup<DctKernel>(i, ctx, in_view, args_);
output_desc[0].shape.set_tensor_shape(i, req.output_shapes[0][0].shape);

if (in_view.shape[args_.axis] > max_length) {
Contributor

Do you check that args_.axis is not out of range?

Contributor Author

will fix


template <typename T, int Dims>
void ApplyLifter(const kernels::OutTensorCPU<T, Dims> &inout, int axis, const T* lifter_coeffs) {
auto* data = inout.data;
Contributor

Any validation of the axis value?

Contributor Author

The constructor validates that it is >= 0, but there is no check for the upper bound. I'll add that in SetupImpl.
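
The missing upper-bound check discussed here amounts to comparing the axis against the number of dimensions of each input sample. A minimal Python sketch of that check follows. `validate_axis` is a hypothetical helper, not a DALI function, and where exactly the operator performs the check is only known from the reply above.

```python
def validate_axis(axis, ndim):
    """Return axis if it indexes a valid dimension of an ndim-D tensor.

    Hypothetical helper for illustration; raises ValueError otherwise.
    """
    if not 0 <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of range for a {ndim}-dimensional input")
    return axis


def is_rejected(axis, ndim):
    """Small helper for demonstration: True if validate_axis raises."""
    try:
        validate_axis(axis, ndim)
    except ValueError:
        return True
    return False


ok_axis = validate_axis(1, ndim=2)      # valid: last axis of a 2-D input
too_large = is_rejected(2, ndim=2)      # upper bound violated
negative = is_rejected(-1, ndim=2)      # lower bound violated
```

Because the input shape is only known per sample, a check like this has to run once the data is available, which is consistent with placing it in a setup step and not in the constructor.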

using Operator<Backend>::RunImpl;

void CalcLifterCoeffs(int64_t length) {
if (static_cast<int64_t>(lifter_coeffs_.size()) >= length || lifter_ == 0)
Contributor

Suggested change
- if (static_cast<int64_t>(lifter_coeffs_.size()) >= length || lifter_ == 0)
+ if (static_cast<int64_t>(lifter_coeffs_.size()) >= length || lifter_ == 0.0)

Contributor

RunImpl already does that check. Do we need it here as well?
Also can we make CalcLifterCoeffs a free function and test it independently?

Contributor Author

Will do
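
To show why a free function is easy to test in isolation, here is a small Python sketch of the coefficient calculation. It follows the formula exactly as quoted in the docstring under review, `(1 + sin(pi * (i + 1) / lifter)) * (lifter / 2)`. The name `calc_lifter_coeffs` and the choice to return an empty list for `lifter == 0` are assumptions for illustration. HTK-style liftering is usually written as `1 + (lifter / 2) * sin(pi * (i + 1) / lifter)`, so the grouping of terms is worth checking against the actual implementation, which this thread does not show.

```python
import math


def calc_lifter_coeffs(length, lifter):
    """Liftering coefficients for indices 0..length-1.

    Uses the formula as quoted in the docstring under review.
    Returns an empty list when liftering is disabled (lifter == 0).
    """
    if lifter == 0.0:
        return []
    return [
        (1.0 + math.sin(math.pi * (i + 1) / lifter)) * (lifter / 2.0)
        for i in range(length)
    ]


disabled = calc_lifter_coeffs(3, 0.0)
coeffs = calc_lifter_coeffs(3, 2.0)
```

With `lifter = 2.0` the sine takes the values 1, 0 and -1 for the first three indices, so the expected coefficients are easy to verify by hand.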

yield check_operator_mfcc_vs_python, device, batch_size, shape, \
    axis, dct_type, lifter, n_mfcc, norm

#check_operator_mfcc_vs_python(device='cpu', batch_size=3, input_shape=(17,1), axis=0,
Contributor

?

Contributor Author

leftover, will remove

for dct_type in [1, 2, 3]:
    for norm in [False] if dct_type == 1 else [True, False]:
        for axis, n_mfcc, lifter, shape in \
                [(0, 17, 0.0, (17, 1)),
Contributor

Can we also have some tests for invalid arguments, to check that the operator fails? I think we don't have many tests that trigger asserts.

Contributor Author

Ok, I'll add some
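
A negative test of this kind could look roughly like the sketch below. The loop above only exercises DCT type 1 without normalization, which suggests that the combination of type 1 with normalization is expected to be rejected. `check_mfcc_args` is a hypothetical stand-in for the operator's argument validation, and the exception type and message are assumptions; a real test would build a DALI pipeline and catch whatever error it raises.

```python
def check_mfcc_args(dct_type, normalize):
    """Hypothetical stand-in for the operator's argument checks (not DALI API)."""
    if normalize and dct_type == 1:
        raise ValueError("normalization is not supported for DCT type 1")


def raises_value_error(func, *args):
    """Return True if calling func(*args) raises ValueError."""
    try:
        func(*args)
    except ValueError:
        return True
    return False


# The one combination assumed to be invalid, and every other tested combination.
invalid_cases = [(1, True)]
valid_cases = [(d, n) for d in (1, 2, 3) for n in (True, False)
               if (d, n) not in invalid_cases]

invalid_rejected = all(raises_value_error(check_mfcc_args, d, n)
                       for d, n in invalid_cases)
valid_accepted = not any(raises_value_error(check_mfcc_args, d, n)
                         for d, n in valid_cases)
```

Checking the valid cases alongside the invalid ones guards against a validator that rejects everything.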

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [1044580]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [1044580]: BUILD PASSED

Contributor

@klecki left a comment

In the notebook (I know some of these are not from this PR):
This is often call spectral leakage -> This is often called spectral leakage

"We can calculate a mel spectrogram in decibels by using the following DALI pipeline." I think it would be nice to describe how we do it in DALI, or rather what this pipeline does, that is, what the sequence of applied operators means with regard to what we have described above.


args_.normalize = spec.GetArgument<bool>("normalize");
if (args_.normalize) {
DALI_ENFORCE(args_.dct_type != 1, "Ortho-normalization is not supported for DCT type I");
Contributor

Just a nitpick :P

Suggested change
- DALI_ENFORCE(args_.dct_type != 1, "Ortho-normalization is not supported for DCT type I");
+ DALI_ENFORCE(args_.dct_type != 1, "Ortho-normalization is not supported for DCT type I.");

Contributor Author

Will fix

auto* data = inout.data;
auto shape = inout.shape;
auto strides = kernels::GetStrides(shape);
kernels::ForAxis(
Contributor

I feel like this can be optimized a bit. I don't know if it's worth the effort, but it would probably yield a better memory access pattern.
You can have a variant of ForAxis for a case like Dims=5, axis=2 and do something like:

for (x0 in Dim0)
  for (x1 in Dim1)
    // we're on our target axis; the same lifter coefficient applies to neighbouring
    // elements, so instead of invoking this lambda and iterating in Dim2 by jumping around
    // the data, process it in groups: coefficient 0, then 1, etc.
    for (x2 in Dim2)
      multiply all Dim3 * Dim4 elements by lifter[x2]

Contributor Author

I agree that in certain configurations this could be optimized for a better access pattern. However, ForAxis is rather a general utility and this would be an optimization that is specific for this particular case (because we can reuse the lifter coefficient). We could write something custom for a certain layout but I think it would be best to keep simplicity/generality here unless we know that there is a real performance problem here.

Contributor

Ack
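
For reference, the access pattern proposed in this thread can be sketched in a few lines of Python on a flat, row-major buffer. This illustrates the reviewer's idea only. It is not the code that was merged, and `apply_lifter_blocked` is a hypothetical name. The dimensions before the target axis are collapsed into `outer`, those after it into `inner`, and each coefficient is applied to one contiguous run of `inner` elements.

```python
def apply_lifter_blocked(data, shape, axis, coeffs):
    """Scale a flat row-major buffer in place along `axis`.

    Every index k along the target axis owns a contiguous block of
    `inner` elements, all multiplied by the same coefficient coeffs[k].
    """
    outer = 1
    for dim in shape[:axis]:
        outer *= dim
    inner = 1
    for dim in shape[axis + 1:]:
        inner *= dim
    axis_len = shape[axis]

    for o in range(outer):
        for k in range(axis_len):
            start = (o * axis_len + k) * inner
            c = coeffs[k]
            for j in range(start, start + inner):  # contiguous run
                data[j] *= c
    return data


# Shape (2, 3, 2) with the lifter applied along axis 1.
buf = apply_lifter_blocked([1.0] * 12, (2, 3, 2), 1, [1.0, 2.0, 3.0])
```

The trade-off raised in the reply still applies: this layout-specific loop is faster to walk through memory but less general than a per-slice utility.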

Comment on lines 79 to 81
If `lifter>0`, the MFCCs will be scaled according to the following formula:

`MFFC[i] = MFCC[i] * (1 + sin(pi * (i + 1) / lifter)) * (lifter / 2)`
Contributor

Maybe use the :: here for the formula?

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [1056217]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [1056217]: BUILD PASSED

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [1060544]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [1060544]: BUILD PASSED

@jantonguirao merged commit 7421e77 into NVIDIA:master on Jan 7, 2020

5 participants