Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable support for different layouts in the MelFilterBank GPU Op #2620

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 77 additions & 49 deletions dali/kernels/audio/mel_scale/mel_filter_bank_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,37 @@ class MelFilterBankCpu<T, Dims>::Impl: public MelFilterImplBase<T, Dims> {
public:
template <typename MelScale>
Impl(MelScale mel_scale, const MelFilterBankArgs &args)
: MelFilterImplBase<T, Dims>(mel_scale, args) {
intervals_.resize(fftbin_size_, -1);
: MelFilterImplBase<T, Dims>(mel_scale, args) {
double mel = mel_low_ + mel_delta_;

int64_t fftbin = fftbin_start_;
double f = fftbin * hz_step_;

int last_interval = args_.nfilter;
for (int64_t interval = 0; interval <= last_interval; interval++, mel += mel_delta_) {
if (interval == last_interval) {
mel = mel_high_;
auto nfilter = args.nfilter;
assert(args.axis == Dims - 1 || args.axis == Dims - 2);
if (args.axis == Dims - 2) {
intervals_.resize(fftbin_size_, -1);
int fftbin = fftbin_start_;
double f = fftbin * hz_step_;
for (int interval = 0; interval < nfilter + 1; interval++, mel += mel_delta_) {
double freq = mel_scale.mel_to_hz(interval == nfilter ? mel_high_ : mel);
for (; fftbin <= fftbin_end_ && f < freq; fftbin++, f = fftbin * hz_step_) {
intervals_[fftbin] = interval;
}
}
double freq = mel_scale.mel_to_hz(mel);
for (; fftbin <= fftbin_end_ && f < freq; fftbin++, f = fftbin * hz_step_) {
intervals_[fftbin] = interval;
} else { // args.axis == Dims - 1
interval_ends_.resize(nfilter + 2);
interval_ends_[0] = fftbin_start_;
interval_ends_[nfilter + 1] = fftbin_end_ + 1;
for (int interval = 1; interval < nfilter + 1; interval++, mel += mel_delta_) {
double freq = mel_scale.mel_to_hz(mel);
interval_ends_[interval] = std::ceil(freq / hz_step_);
}
}
}

void Compute(T* out, const T* in, int64_t nwindows,
int64_t out_stride = -1, int64_t in_stride = -1) {
if (out_stride <= 0)
out_stride = nwindows;

if (in_stride <= 0)
in_stride = nwindows;

void ComputeFreqMajor(T* out, const T* in, int64_t nwindows) {
int nfilter = args_.nfilter;

std::memset(out, 0, sizeof(T) * nfilter * nwindows);
for (int64_t fftbin = fftbin_start_; fftbin <= fftbin_end_; fftbin++) {
auto *in_row_start = in + fftbin * in_stride;
auto *in_row_start = in + fftbin * nwindows;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Having strides had some potential for more generic usage.

auto filter_up = intervals_[fftbin];
auto weight_up = T(1) - weights_down_[fftbin];
auto filter_down = filter_up - 1;
Expand All @@ -84,7 +83,7 @@ class MelFilterBankCpu<T, Dims>::Impl: public MelFilterImplBase<T, Dims> {
if (filter_down >= 0) {
if (args_.normalize)
weight_down *= norm_factors_[filter_down];
auto *out_row_start = out + filter_down * out_stride;
auto *out_row_start = out + filter_down * nwindows;
for (int t = 0; t < nwindows; t++) {
out_row_start[t] += weight_down * in_row_start[t];
}
Expand All @@ -93,16 +92,43 @@ class MelFilterBankCpu<T, Dims>::Impl: public MelFilterImplBase<T, Dims> {
if (filter_up >= 0 && filter_up < nfilter) {
if (args_.normalize)
weight_up *= norm_factors_[filter_up];
auto *out_row_start = out + filter_up * out_stride;
auto *out_row_start = out + filter_up * nwindows;
for (int t = 0; t < nwindows; t++) {
out_row_start[t] += weight_up * in_row_start[t];
}
}
}
}

void ComputeTimeMajor(T* out, const T* in, int64_t nwindows) {
int nfilter = args_.nfilter;
for (int t = 0; t < nwindows; t++) {
const T *in_row = in + t * fftbin_size_;
for (int m = 0; m < nfilter; m++) {
T val = 0;
int fftbin = interval_ends_[m];
int f1 = interval_ends_[m + 1];
int f2 = interval_ends_[m + 2];
for (; fftbin < f1; ++fftbin) {
auto weight_up = T(1) - weights_down_[fftbin];
if (args_.normalize)
Copy link
Contributor

@mzient mzient Jan 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move it after the loops as val *= norm_factors_[m];

weight_up *= norm_factors_[m];
val += in_row[fftbin] * weight_up;
}
for (; fftbin < f2; ++fftbin) {
auto weight_down = weights_down_[fftbin];
if (args_.normalize)
weight_down *= norm_factors_[m];
val += in_row[fftbin] * weight_down;
}
*out++ = val;
}
}
}

private:
std::vector<int> intervals_;
std::vector<int> interval_ends_;
USE_MEL_FILTER_IMPL_MEMBERS(T, Dims);
};

Expand All @@ -113,15 +139,14 @@ template <typename T, int Dims>
MelFilterBankCpu<T, Dims>::~MelFilterBankCpu() = default;

template <typename T, int Dims>
KernelRequirements MelFilterBankCpu<T, Dims>::Setup(
KernelContext &context,
const InTensorCPU<T, Dims> &in,
const MelFilterBankArgs &original_args) {
auto args = original_args;
KernelRequirements MelFilterBankCpu<T, Dims>::Setup(KernelContext &context,
const InTensorCPU<T, Dims> &in,
const MelFilterBankArgs &orig_args) {
auto args = orig_args;
args.axis = args.axis >= 0 ? args.axis : Dims - 2;
DALI_ENFORCE(args.axis == Dims - 2,
"Input is expected to be a spectrogram with the last two dimensions being FFT bin index and "
"window index respectively");
DALI_ENFORCE(args.axis == Dims - 2 || args.axis == Dims - 1,
"Input is expected to be a spectrogram with the last two dimensions being "
"(fftbin_idx, frame_idx), frequency major, or (frame_idx, fftbin_idx), time major.");
auto out_shape = in.shape;
out_shape[args.axis] = args.nfilter;

Expand All @@ -147,28 +172,31 @@ KernelRequirements MelFilterBankCpu<T, Dims>::Setup(
}

template <typename T, int Dims>
void MelFilterBankCpu<T, Dims>::Run(
KernelContext &context,
const OutTensorCPU<T, Dims> &out,
const InTensorCPU<T, Dims> &in,
const MelFilterBankArgs &original_args) {
(void) original_args;
void MelFilterBankCpu<T, Dims>::Run(KernelContext &context, const OutTensorCPU<T, Dims> &out,
const InTensorCPU<T, Dims> &in) {
DALI_ENFORCE(impl_ != nullptr);
const auto &args = impl_->Args();
assert(args.axis == Dims - 2 || args.axis == Dims - 1);
auto in_shape = in.shape;
auto nwin = in_shape[Dims - 1];
auto in_strides = GetStrides(in_shape);
auto out_shape = out.shape;
auto out_strides = GetStrides(out_shape);
auto for_axis_ndim = out.dim() - 1; // squeeze last dim
ForAxis(
out.data, in.data, out_shape.data(), out_strides.data(), in_shape.data(), in_strides.data(),
args.axis, for_axis_ndim,
[this, nwin](
T *out_data, const T *in_data,
int64_t out_size, int64_t out_stride, int64_t in_size, int64_t in_stride) {
impl_->Compute(out_data, in_data, nwin);
});
auto in_strides = GetStrides(in_shape);

if (args.axis == Dims - 2) {
auto nwin = in_shape[Dims - 1];
ForAxis(out.data, in.data, out_shape.data(), out_strides.data(), in_shape.data(),
in_strides.data(), Dims - 2,
Dims - 1, // Iterating slices of the two last dimensions
[this, nwin](T *out_data, const T *in_data, int64_t out_size, int64_t out_stride,
int64_t in_size, int64_t in_stride) {
impl_->ComputeFreqMajor(out_data, in_data, nwin);
});
} else {
int64_t nwin = 1;
for (int d = 0; d < Dims - 1; d++)
nwin *= in_shape[d];
impl_->ComputeTimeMajor(out.data, in.data, nwin);
}
}

template class MelFilterBankCpu<float, 2>;
Expand Down
3 changes: 1 addition & 2 deletions dali/kernels/audio/mel_scale/mel_filter_bank_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ class DLL_PUBLIC MelFilterBankCpu {

DLL_PUBLIC void Run(KernelContext &context,
const OutTensorCPU<T, Dims> &out,
const InTensorCPU<T, Dims> &in,
const MelFilterBankArgs &args);
const InTensorCPU<T, Dims> &in);

private:
class Impl;
Expand Down
2 changes: 1 addition & 1 deletion dali/kernels/audio/mel_scale/mel_filter_bank_cpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ TEST_P(MelScaleCpuTest, MelScaleCpuTest) {
std::vector<T> out(out_size, 0.0f);
auto out_view = OutTensorCPU<T, Dims>(out.data(), out_shape.to_static<Dims>());

kernel.Run(ctx, out_view, in_view_, args);
kernel.Run(ctx, out_view, in_view_);

LOG_LINE << "in:\n";
print_data(in_view_);
Expand Down
6 changes: 3 additions & 3 deletions dali/kernels/audio/mel_scale/mel_filter_bank_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ template <typename T, int Dims>
class MelFilterBankGpu<T, Dims>::Impl : public MelFilterImplBase<T, Dims> {
public:
template <typename MelScale>
Impl(MelScale mel_scale, const MelFilterBankArgs &args)
: MelFilterImplBase<T, Dims>(mel_scale, args)
, interval_ends_(args.nfilter+2) {
Impl(MelScale mel_scale, const MelFilterBankArgs &args) :
MelFilterImplBase<T, Dims>(mel_scale, args),
interval_ends_(args.nfilter + 2) {
double mel = mel_low_ + mel_delta_;
interval_ends_[0] = fftbin_start_;
interval_ends_[args.nfilter + 1] = fftbin_end_ + 1;
Expand Down
9 changes: 5 additions & 4 deletions dali/operators/audio/mel_scale/mel_filter_bank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ DALI_SCHEMA(MelFilterBank)
.DocStr(R"code(Converts a spectrogram to a mel spectrogram by applying a bank of
triangular filters.

Expects an input with at least 2 dimensions where the last two dimensions correspond to
the fft bin index and the window index, respectively.
Expects an input with at least 2 dimensions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be a requirement - I mean, you can just calculate a single Mel spectrum (not necessarily a spectrogram).
It can be worked around by adding artificial dimensions for use by the kernels.
If, on the other hand, there are more than 2 dimensions, all trailing and leading dimensions can be collapsed.
Can we even handle a case when f is somewhere in the middle - think, AfB layout?

Copy link
Collaborator Author

@banasraf banasraf Jan 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GPU kernel supports any f axis

The frequency ('f') dimension must be present in the layout and should be one of the last two
dimensions in the layout.
)code")
.NumInput(kNumInputs)
.NumOutput(kNumOutputs)
Expand Down Expand Up @@ -73,7 +74,7 @@ bool MelFilterBank<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
auto in_shape = input.shape();
int nsamples = input.size();
auto nthreads = ws.GetThreadPool().size();

args_.axis = input.GetLayout().find('f');
TYPE_SWITCH(input.type().id(), type2id, T, MEL_FBANK_SUPPORTED_TYPES, (
VALUE_SWITCH(in_shape.sample_dim(), Dims, MEL_FBANK_SUPPORTED_NDIMS, (
using MelFilterBankKernel = kernels::audio::MelFilterBankCpu<T, Dims>;
Expand Down Expand Up @@ -106,7 +107,7 @@ void MelFilterBank<CPUBackend>::RunImpl(workspace_t<CPUBackend> &ws) {
[this, &input, &output, i](int thread_id) {
auto in_view = view<const T, Dims>(input[i]);
auto out_view = view<T, Dims>(output[i]);
kmgr_.Run<MelFilterBankKernel>(thread_id, i, ctx_, out_view, in_view, args_);
kmgr_.Run<MelFilterBankKernel>(thread_id, i, ctx_, out_view, in_view);
}, in_shape.tensor_size(i));
}
), DALI_FAIL(make_string("Unsupported number of dimensions ", in_shape.size()))); // NOLINT
Expand Down
1 change: 1 addition & 0 deletions dali/operators/audio/mel_scale/mel_filter_bank_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bool MelFilterBank<GPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
const workspace_t<GPUBackend> &ws) {
output_desc.resize(kNumOutputs);
const auto &input = ws.InputRef<GPUBackend>(0);
args_.axis = input.GetLayout().find('f');
ctx_.gpu.stream = ws.stream();
const auto &in_shape = input.shape();
TYPE_SWITCH(input.type().id(), type2id, T, MEL_FBANK_SUPPORTED_TYPES, (
Expand Down
51 changes: 33 additions & 18 deletions dali/test/python/test_operator_mel_filter_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class MelFilterBankPipeline(Pipeline):
def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low, freq_high,
normalize, mel_formula, num_threads=1, device_id=0):
normalize, mel_formula, layout='ft', num_threads=1, device_id=0):
super(MelFilterBankPipeline, self).__init__(batch_size, num_threads, device_id)
self.device = device
self.iterator = iterator
Expand All @@ -40,6 +40,7 @@ def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low,
freq_high = freq_high,
normalize = normalize,
mel_formula = mel_formula)
self.layout=layout

def define_graph(self):
self.data = self.inputs()
Expand All @@ -49,7 +50,7 @@ def define_graph(self):

def iter_setup(self):
data = self.iterator.next()
self.feed_input(self.data, data)
self.feed_input(self.data, data, layout=self.layout)

def mel_fbank_func(nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula, input_data):
in_shape = input_data.shape
Expand All @@ -75,7 +76,7 @@ def mel_fbank_func(nfilter, sample_rate, freq_low, freq_high, normalize, mel_for

class MelFilterBankPythonPipeline(Pipeline):
def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low, freq_high,
normalize, mel_formula, num_threads=1, device_id=0, func=mel_fbank_func):
normalize, mel_formula, layout='ft', num_threads=1, device_id=0, func=mel_fbank_func):
super(MelFilterBankPythonPipeline, self).__init__(
batch_size, num_threads, device_id,
seed=12345, exec_async=False, exec_pipelined=False)
Expand All @@ -85,43 +86,57 @@ def __init__(self, device, batch_size, iterator, nfilter, sample_rate, freq_low,

function = partial(func, nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula)
self.mel_fbank = ops.PythonFunction(function=function)
self.layout=layout
self.freq_major = layout.find('f') != len(layout) - 1
if not self.freq_major:
perm = [i for i in range(len(layout))]
f = layout.find('f')
perm[f] = len(layout) - 2
perm[-2] = f
self.transpose = ops.Transpose(perm=perm)

def _transposed(self, op):
return lambda x: self.transpose(op(self.transpose(x)))

def define_graph(self):
self.data = self.inputs()
out = self.mel_fbank(self.data)
mel_fbank = self.mel_fbank if self.freq_major else self._transposed(self.mel_fbank)
out = mel_fbank(self.data)
return out

def iter_setup(self):
data = self.iterator.next()
self.feed_input(self.data, data)
self.feed_input(self.data, data, layout=self.layout)

def check_operator_mel_filter_bank_vs_python(device, batch_size, max_shape,
nfilter, sample_rate, freq_low, freq_high,
normalize, mel_formula):
min_shape = [max_shape[0], 1]
normalize, mel_formula, layout):
f_axis = layout.find('f')
min_shape = [1 for _ in max_shape]
min_shape[f_axis] = max_shape[f_axis]
eii1 = RandomlyShapedDataIterator(batch_size, min_shape=min_shape, max_shape=max_shape, dtype=np.float32)
eii2 = RandomlyShapedDataIterator(batch_size, min_shape=min_shape, max_shape=max_shape, dtype=np.float32)
compare_pipelines(
MelFilterBankPipeline(device, batch_size, iter(eii1),
nfilter=nfilter, sample_rate=sample_rate, freq_low=freq_low, freq_high=freq_high,
normalize=normalize, mel_formula=mel_formula),
normalize=normalize, mel_formula=mel_formula, layout=layout),
MelFilterBankPythonPipeline(device, batch_size, iter(eii2),
nfilter=nfilter, sample_rate=sample_rate, freq_low=freq_low, freq_high=freq_high,
normalize=normalize, mel_formula=mel_formula),
normalize=normalize, mel_formula=mel_formula, layout=layout),
batch_size=batch_size, N_iterations=5, eps=1e-03)

def test_operator_mel_filter_bank_vs_python():
for device in ['cpu', 'gpu']:
for batch_size in [1, 3]:
for normalize in [True, False]:
for mel_formula in ['htk', 'slaney']:
for nfilter, sample_rate, freq_low, freq_high, shape in \
[(4, 16000.0, 0.0, 8000.0, (17, 1)),
(128, 16000.0, 0.0, 8000.0, (513, 100)),
(128, 16000.0, 0.0, 8000.0, (10, 513, 100)),
(128, 48000.0, 0.0, 24000.0, (513, 100)),
(128, 48000.0, 4000.0, 24000.0, (513, 100)),
(128, 44100.0, 0.0, 22050.0, (513, 100)),
(128, 44100.0, 1000.0, 22050.0, (513, 100))]:
for nfilter, sample_rate, freq_low, freq_high, shape, layout in \
[(4, 16000.0, 0.0, 8000.0, (17, 1), 'ft'),
(128, 16000.0, 0.0, 8000.0, (513, 100), 'ft'),
(128, 48000.0, 0.0, 24000.0, (513, 100), 'ft'),
(128, 16000.0, 0.0, 8000.0, (10, 513, 100), 'Ctf'),
(128, 48000.0, 4000.0, 24000.0, (513, 100), 'tf'),
(128, 44100.0, 0.0, 22050.0, (513, 100), 'tf'),
(128, 44100.0, 1000.0, 22050.0, (513, 100), 'tf')]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test with 1D input.

yield check_operator_mel_filter_bank_vs_python, device, batch_size, shape, \
nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula
nfilter, sample_rate, freq_low, freq_high, normalize, mel_formula, layout