-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable support for different layouts in the MelFilterBank GPU Op #2620
Changes from 2 commits
eb12e10
f714b22
e3bc841
cb1f0c8
6483b2c
6fc1b7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,38 +44,37 @@ class MelFilterBankCpu<T, Dims>::Impl: public MelFilterImplBase<T, Dims> { | |
public: | ||
template <typename MelScale> | ||
Impl(MelScale mel_scale, const MelFilterBankArgs &args) | ||
: MelFilterImplBase<T, Dims>(mel_scale, args) { | ||
intervals_.resize(fftbin_size_, -1); | ||
: MelFilterImplBase<T, Dims>(mel_scale, args) { | ||
double mel = mel_low_ + mel_delta_; | ||
|
||
int64_t fftbin = fftbin_start_; | ||
double f = fftbin * hz_step_; | ||
|
||
int last_interval = args_.nfilter; | ||
for (int64_t interval = 0; interval <= last_interval; interval++, mel += mel_delta_) { | ||
if (interval == last_interval) { | ||
mel = mel_high_; | ||
auto nfilter = args.nfilter; | ||
assert(args.axis == Dims - 1 || args.axis == Dims - 2); | ||
if (args.axis == Dims - 2) { | ||
intervals_.resize(fftbin_size_, -1); | ||
int fftbin = fftbin_start_; | ||
double f = fftbin * hz_step_; | ||
for (int interval = 0; interval < nfilter + 1; interval++, mel += mel_delta_) { | ||
double freq = mel_scale.mel_to_hz(interval == nfilter ? mel_high_ : mel); | ||
for (; fftbin <= fftbin_end_ && f < freq; fftbin++, f = fftbin * hz_step_) { | ||
intervals_[fftbin] = interval; | ||
} | ||
} | ||
double freq = mel_scale.mel_to_hz(mel); | ||
for (; fftbin <= fftbin_end_ && f < freq; fftbin++, f = fftbin * hz_step_) { | ||
intervals_[fftbin] = interval; | ||
} else { // args.axis == Dims - 1 | ||
interval_ends_.resize(nfilter + 2); | ||
interval_ends_[0] = fftbin_start_; | ||
interval_ends_[nfilter + 1] = fftbin_end_ + 1; | ||
for (int interval = 1; interval < nfilter + 1; interval++, mel += mel_delta_) { | ||
double freq = mel_scale.mel_to_hz(mel); | ||
interval_ends_[interval] = std::ceil(freq / hz_step_); | ||
} | ||
} | ||
} | ||
|
||
void Compute(T* out, const T* in, int64_t nwindows, | ||
int64_t out_stride = -1, int64_t in_stride = -1) { | ||
if (out_stride <= 0) | ||
out_stride = nwindows; | ||
|
||
if (in_stride <= 0) | ||
in_stride = nwindows; | ||
|
||
void ComputeFreqMajor(T* out, const T* in, int64_t nwindows) { | ||
int nfilter = args_.nfilter; | ||
|
||
std::memset(out, 0, sizeof(T) * nfilter * nwindows); | ||
for (int64_t fftbin = fftbin_start_; fftbin <= fftbin_end_; fftbin++) { | ||
auto *in_row_start = in + fftbin * in_stride; | ||
auto *in_row_start = in + fftbin * nwindows; | ||
auto filter_up = intervals_[fftbin]; | ||
auto weight_up = T(1) - weights_down_[fftbin]; | ||
auto filter_down = filter_up - 1; | ||
|
@@ -84,7 +83,7 @@ class MelFilterBankCpu<T, Dims>::Impl: public MelFilterImplBase<T, Dims> { | |
if (filter_down >= 0) { | ||
if (args_.normalize) | ||
weight_down *= norm_factors_[filter_down]; | ||
auto *out_row_start = out + filter_down * out_stride; | ||
auto *out_row_start = out + filter_down * nwindows; | ||
for (int t = 0; t < nwindows; t++) { | ||
out_row_start[t] += weight_down * in_row_start[t]; | ||
} | ||
|
@@ -93,16 +92,43 @@ class MelFilterBankCpu<T, Dims>::Impl: public MelFilterImplBase<T, Dims> { | |
if (filter_up >= 0 && filter_up < nfilter) { | ||
if (args_.normalize) | ||
weight_up *= norm_factors_[filter_up]; | ||
auto *out_row_start = out + filter_up * out_stride; | ||
auto *out_row_start = out + filter_up * nwindows; | ||
for (int t = 0; t < nwindows; t++) { | ||
out_row_start[t] += weight_up * in_row_start[t]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
void ComputeTimeMajor(T* out, const T* in, int64_t nwindows) { | ||
int nfilter = args_.nfilter; | ||
for (int t = 0; t < nwindows; t++) { | ||
const T *in_row = in + t * fftbin_size_; | ||
for (int m = 0; m < nfilter; m++) { | ||
T val = 0; | ||
int fftbin = interval_ends_[m]; | ||
int f1 = interval_ends_[m + 1]; | ||
int f2 = interval_ends_[m + 2]; | ||
for (; fftbin < f1; ++fftbin) { | ||
auto weight_up = T(1) - weights_down_[fftbin]; | ||
if (args_.normalize) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please move it after the loops as |
||
weight_up *= norm_factors_[m]; | ||
val += in_row[fftbin] * weight_up; | ||
} | ||
for (; fftbin < f2; ++fftbin) { | ||
auto weight_down = weights_down_[fftbin]; | ||
if (args_.normalize) | ||
weight_down *= norm_factors_[m]; | ||
val += in_row[fftbin] * weight_down; | ||
} | ||
*out++ = val; | ||
} | ||
} | ||
} | ||
|
||
private: | ||
std::vector<int> intervals_; | ||
std::vector<int> interval_ends_; | ||
USE_MEL_FILTER_IMPL_MEMBERS(T, Dims); | ||
}; | ||
|
||
|
@@ -113,15 +139,14 @@ template <typename T, int Dims> | |
MelFilterBankCpu<T, Dims>::~MelFilterBankCpu() = default; | ||
|
||
template <typename T, int Dims> | ||
KernelRequirements MelFilterBankCpu<T, Dims>::Setup( | ||
KernelContext &context, | ||
const InTensorCPU<T, Dims> &in, | ||
const MelFilterBankArgs &original_args) { | ||
auto args = original_args; | ||
KernelRequirements MelFilterBankCpu<T, Dims>::Setup(KernelContext &context, | ||
const InTensorCPU<T, Dims> &in, | ||
const MelFilterBankArgs &orig_args) { | ||
auto args = orig_args; | ||
args.axis = args.axis >= 0 ? args.axis : Dims - 2; | ||
DALI_ENFORCE(args.axis == Dims - 2, | ||
"Input is expected to be a spectrogram with the last two dimensions being FFT bin index and " | ||
"window index respectively"); | ||
DALI_ENFORCE(args.axis == Dims - 2 || args.axis == Dims - 1, | ||
"Input is expected to be a spectrogram with the last two dimensions being " | ||
"(fftbin_idx, frame_idx), frequency major, or (frame_idx, fftbin_idx), time major."); | ||
auto out_shape = in.shape; | ||
out_shape[args.axis] = args.nfilter; | ||
|
||
|
@@ -147,28 +172,31 @@ KernelRequirements MelFilterBankCpu<T, Dims>::Setup( | |
} | ||
|
||
template <typename T, int Dims> | ||
void MelFilterBankCpu<T, Dims>::Run( | ||
KernelContext &context, | ||
const OutTensorCPU<T, Dims> &out, | ||
const InTensorCPU<T, Dims> &in, | ||
const MelFilterBankArgs &original_args) { | ||
(void) original_args; | ||
void MelFilterBankCpu<T, Dims>::Run(KernelContext &context, const OutTensorCPU<T, Dims> &out, | ||
const InTensorCPU<T, Dims> &in) { | ||
DALI_ENFORCE(impl_ != nullptr); | ||
const auto &args = impl_->Args(); | ||
assert(args.axis == Dims - 2 || args.axis == Dims - 1); | ||
auto in_shape = in.shape; | ||
auto nwin = in_shape[Dims - 1]; | ||
auto in_strides = GetStrides(in_shape); | ||
auto out_shape = out.shape; | ||
auto out_strides = GetStrides(out_shape); | ||
auto for_axis_ndim = out.dim() - 1; // squeeze last dim | ||
ForAxis( | ||
out.data, in.data, out_shape.data(), out_strides.data(), in_shape.data(), in_strides.data(), | ||
args.axis, for_axis_ndim, | ||
[this, nwin]( | ||
T *out_data, const T *in_data, | ||
int64_t out_size, int64_t out_stride, int64_t in_size, int64_t in_stride) { | ||
impl_->Compute(out_data, in_data, nwin); | ||
}); | ||
auto in_strides = GetStrides(in_shape); | ||
|
||
if (args.axis == Dims - 2) { | ||
auto nwin = in_shape[Dims - 1]; | ||
ForAxis(out.data, in.data, out_shape.data(), out_strides.data(), in_shape.data(), | ||
in_strides.data(), Dims - 2, | ||
Dims - 1, // Iterating slices of the two last dimensions | ||
[this, nwin](T *out_data, const T *in_data, int64_t out_size, int64_t out_stride, | ||
int64_t in_size, int64_t in_stride) { | ||
impl_->ComputeFreqMajor(out_data, in_data, nwin); | ||
}); | ||
} else { | ||
int64_t nwin = 1; | ||
for (int d = 0; d < Dims - 1; d++) | ||
nwin *= in_shape[d]; | ||
impl_->ComputeTimeMajor(out.data, in.data, nwin); | ||
} | ||
} | ||
|
||
template class MelFilterBankCpu<float, 2>; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,8 +23,9 @@ DALI_SCHEMA(MelFilterBank) | |
.DocStr(R"code(Converts a spectrogram to a mel spectrogram by applying a bank of | ||
triangular filters. | ||
|
||
Expects an input with at least 2 dimensions where the last two dimensions correspond to | ||
the fft bin index and the window index, respectively. | ||
Expects an input with at least 2 dimensions. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should not be a requirement - I mean, you can just calculate a single Mel spectrum (not necessarily a spectrogram). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The GPU kernel supports any f axis |
||
The frequency ('f') dimension must be present in the layout and should be one of the last two | ||
dimensions in the layout. | ||
)code") | ||
.NumInput(kNumInputs) | ||
.NumOutput(kNumOutputs) | ||
|
@@ -73,7 +74,7 @@ bool MelFilterBank<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc, | |
auto in_shape = input.shape(); | ||
int nsamples = input.size(); | ||
auto nthreads = ws.GetThreadPool().size(); | ||
|
||
args_.axis = input.GetLayout().find('f'); | ||
TYPE_SWITCH(input.type().id(), type2id, T, MEL_FBANK_SUPPORTED_TYPES, ( | ||
VALUE_SWITCH(in_shape.sample_dim(), Dims, MEL_FBANK_SUPPORTED_NDIMS, ( | ||
using MelFilterBankKernel = kernels::audio::MelFilterBankCpu<T, Dims>; | ||
|
@@ -106,7 +107,7 @@ void MelFilterBank<CPUBackend>::RunImpl(workspace_t<CPUBackend> &ws) { | |
[this, &input, &output, i](int thread_id) { | ||
auto in_view = view<const T, Dims>(input[i]); | ||
auto out_view = view<T, Dims>(output[i]); | ||
kmgr_.Run<MelFilterBankKernel>(thread_id, i, ctx_, out_view, in_view, args_); | ||
kmgr_.Run<MelFilterBankKernel>(thread_id, i, ctx_, out_view, in_view); | ||
}, in_shape.tensor_size(i)); | ||
} | ||
), DALI_FAIL(make_string("Unsupported number of dimensions ", in_shape.size()))); // NOLINT | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
|
||
class MelFilterBankPipeline(Pipeline): | ||
def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low, freq_high, | ||
normalize, mel_formula, num_threads=1, device_id=0): | ||
normalize, mel_formula, layout='ft', num_threads=1, device_id=0): | ||
super(MelFilterBankPipeline, self).__init__(batch_size, num_threads, device_id) | ||
self.device = device | ||
self.iterator = iterator | ||
|
@@ -40,6 +40,7 @@ def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low, | |
freq_high = freq_high, | ||
normalize = normalize, | ||
mel_formula = mel_formula) | ||
self.layout=layout | ||
|
||
def define_graph(self): | ||
self.data = self.inputs() | ||
|
@@ -49,7 +50,7 @@ def define_graph(self): | |
|
||
def iter_setup(self): | ||
data = self.iterator.next() | ||
self.feed_input(self.data, data) | ||
self.feed_input(self.data, data, layout=self.layout) | ||
|
||
def mel_fbank_func(nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula, input_data): | ||
in_shape = input_data.shape | ||
|
@@ -75,7 +76,7 @@ def mel_fbank_func(nfilter, sample_rate, freq_low, freq_high, normalize, mel_for | |
|
||
class MelFilterBankPythonPipeline(Pipeline): | ||
def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low, freq_high, | ||
normalize, mel_formula, num_threads=1, device_id=0, func=mel_fbank_func): | ||
normalize, mel_formula, layout='ft', num_threads=1, device_id=0, func=mel_fbank_func): | ||
super(MelFilterBankPythonPipeline, self).__init__( | ||
batch_size, num_threads, device_id, | ||
seed=12345, exec_async=False, exec_pipelined=False) | ||
|
@@ -85,43 +86,57 @@ def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low, | |
|
||
function = partial(func, nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula) | ||
self.mel_fbank = ops.PythonFunction(function=function) | ||
self.layout=layout | ||
self.freq_major = layout.find('f') != len(layout) - 1 | ||
if not self.freq_major: | ||
perm = [i for i in range(len(layout))] | ||
f = layout.find('f') | ||
perm[f] = len(layout) - 2 | ||
perm[-2] = f | ||
self.transpose = ops.Transpose(perm=perm) | ||
|
||
def _transposed(self, op): | ||
return lambda x: self.transpose(op(self.transpose(x))) | ||
|
||
def define_graph(self): | ||
self.data = self.inputs() | ||
out = self.mel_fbank(self.data) | ||
mel_fbank = self.mel_fbank if self.freq_major else self._transposed(self.mel_fbank) | ||
out = mel_fbank(self.data) | ||
return out | ||
|
||
def iter_setup(self): | ||
data = self.iterator.next() | ||
self.feed_input(self.data, data) | ||
self.feed_input(self.data, data, layout=self.layout) | ||
|
||
def check_operator_mel_filter_bank_vs_python(device, batch_size, max_shape, | ||
nfilter, sample_rate, freq_low, freq_high, | ||
normalize, mel_formula): | ||
min_shape = [max_shape[0], 1] | ||
normalize, mel_formula, layout): | ||
f_axis = layout.find('f') | ||
min_shape = [1 for _ in max_shape] | ||
min_shape[f_axis] = max_shape[f_axis] | ||
eii1 = RandomlyShapedDataIterator(batch_size, min_shape=min_shape, max_shape=max_shape, dtype=np.float32) | ||
eii2 = RandomlyShapedDataIterator(batch_size, min_shape=min_shape, max_shape=max_shape, dtype=np.float32) | ||
compare_pipelines( | ||
MelFilterBankPipeline(device, batch_size, iter(eii1), | ||
nfilter=nfilter, sample_rate=sample_rate, freq_low=freq_low, freq_high=freq_high, | ||
normalize=normalize, mel_formula=mel_formula), | ||
normalize=normalize, mel_formula=mel_formula, layout=layout), | ||
MelFilterBankPythonPipeline(device, batch_size, iter(eii2), | ||
nfilter=nfilter, sample_rate=sample_rate, freq_low=freq_low, freq_high=freq_high, | ||
normalize=normalize, mel_formula=mel_formula), | ||
normalize=normalize, mel_formula=mel_formula, layout=layout), | ||
batch_size=batch_size, N_iterations=5, eps=1e-03) | ||
|
||
def test_operator_mel_filter_bank_vs_python(): | ||
for device in ['cpu', 'gpu']: | ||
for batch_size in [1, 3]: | ||
for normalize in [True, False]: | ||
for mel_formula in ['htk', 'slaney']: | ||
for nfilter, sample_rate, freq_low, freq_high, shape in \ | ||
[(4, 16000.0, 0.0, 8000.0, (17, 1)), | ||
(128, 16000.0, 0.0, 8000.0, (513, 100)), | ||
(128, 16000.0, 0.0, 8000.0, (10, 513, 100)), | ||
(128, 48000.0, 0.0, 24000.0, (513, 100)), | ||
(128, 48000.0, 4000.0, 24000.0, (513, 100)), | ||
(128, 44100.0, 0.0, 22050.0, (513, 100)), | ||
(128, 44100.0, 1000.0, 22050.0, (513, 100))]: | ||
for nfilter, sample_rate, freq_low, freq_high, shape, layout in \ | ||
[(4, 16000.0, 0.0, 8000.0, (17, 1), 'ft'), | ||
(128, 16000.0, 0.0, 8000.0, (513, 100), 'ft'), | ||
(128, 48000.0, 0.0, 24000.0, (513, 100), 'ft'), | ||
(128, 16000.0, 0.0, 8000.0, (10, 513, 100), 'Ctf'), | ||
(128, 48000.0, 4000.0, 24000.0, (513, 100), 'tf'), | ||
(128, 44100.0, 0.0, 22050.0, (513, 100), 'tf'), | ||
(128, 44100.0, 1000.0, 22050.0, (513, 100), 'tf')]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a test with 1D input. |
||
yield check_operator_mel_filter_bank_vs_python, device, batch_size, shape, \ | ||
nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula | ||
nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula, layout |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? Having strides had some potential for more generic usage.